diff --git a/lib/Command/AI.php b/lib/Command/AI.php
index e02e7120a..e63198d84 100644
--- a/lib/Command/AI.php
+++ b/lib/Command/AI.php
@@ -23,8 +23,8 @@
namespace OCA\Memories\Command;
-use OC\DB\QueryBuilder\QueryBuilder;
use OCA\Memories\Db\FsManager;
+use OCA\Memories\Db\SQL;
use OCA\Memories\Db\TimelineQuery;
use OCA\Memories\Db\TimelineRoot;
use OCA\Memories\Util;
@@ -43,6 +43,7 @@
const API_IMAGES = '/images';
const API_TEXT = '/text';
+const VECTOR_SIZE = 768;
class AIOpts
{
@@ -85,24 +86,21 @@ public function search(string $prompt): array
$query = $this->connection->getQueryBuilder();
$classlist = array_map(static fn (array $class): int => $class['index'], $response['classes']);
- // $classlist = array_slice($classlist, 0, 1);
+ // $classlist = \array_slice($classlist, 0, 8);
$classQuery = $this->connection->getQueryBuilder();
$classQuery->select('c.word')
->from('memories_ss_class', 'c')
- ->where($query->expr()->andX(
- $query->expr()->eq('c.fileid', 'v.fileid'),
- $query->expr()->orX(
- ...array_map(static fn ($idx) =>
- $query->expr()->eq('c.class', $query->expr()->literal($idx)),
- $classlist)
- ),
- ));
+ ->where($classQuery->expr()->andX(
+ $classQuery->expr()->eq('c.fileid', 'v.fileid'),
+ $classQuery->expr()->in('c.class', array_map(static fn ($idx) => $classQuery->expr()->literal($idx), $classlist)),
+ ))
+ ;
$subquery = $this->connection->getQueryBuilder();
$subquery->select('v.fileid')
->from('memories_ss_vectors', 'v')
- ->where($subquery->createFunction("EXISTS ({$classQuery->getSql()})"))
+ ->where(SQL::exists($query, $classQuery))
->groupBy('v.fileid')
;
@@ -110,66 +108,71 @@ public function search(string $prompt): array
$components = [];
foreach ($response['embedding'] as $i => $value) {
$value = number_format($value, 6);
- $components[] = "v.v{$i}*{$value}";
+ $components[] = "(v.v{$i}*({$value}))";
}
- // Divide the operators into chunks of 96 each
- $sums = array_chunk($components, 96);
+ // Divide the operators into chunks of 48 each
+ $sums = array_chunk($components, 48);
// Add the sum of each chunk
for ($i = 0; $i < \count($sums); ++$i) {
- $sum = implode('+', $sums[$i]);
- $subquery->addSelect($subquery->createFunction("({$sum}) as score{$i}"));
+ $sum = $subquery->createFunction(implode('+', $sums[$i]));
+ $subquery->selectAlias($sum, "score{$i}");
}
// Create outer query
$query->select('sq.fileid')
- ->from($query->createFunction("({$subquery->getSQL()}) sq"))
+ ->from(SQL::subquery($query, $subquery, 'sq'))
;
// Add all score sums together
- $sum = implode('+', array_map(static fn ($_, $i) => "score{$i}", $sums, array_keys($sums)));
- $query->addSelect($query->createFunction("({$sum}) as score"));
+ $finalSum = implode('+', array_map(static fn ($_, $i) => "score{$i}", $sums, array_keys($sums)));
+ $finalSum = $query->createFunction("({$finalSum})");
+ $query->selectAlias($finalSum, 'score');
// Filter for scores less than 1
- // $query->andWhere($query->createFunction("(({$sum}) > 0.04)"));
+ $query = SQL::materialize($query, 'fsq');
+ $query->andWhere($query->expr()->gt('fsq.score', $query->expr()->literal(0.04)));
- $query->orderBy('score', 'DESC');
+ $query->orderBy('fsq.score', 'DESC');
// $query->setMaxResults(8); // batch size
- header('Content-Type: text/html');
+
+ // SQL::debugQuery($query);
$t1 = microtime(true);
+
$res = $query->executeQuery()->fetchAll();
// print length and discard after 10
- echo "
Results: ".\count($res)."
";
- $res = array_slice($res, 0, 10);
+ echo 'Results: '.\count($res).'
';
+ $res = \array_slice($res, 0, 10);
$t2 = microtime(true);
- echo "Search took ".(($t2 - $t1)*1000)." ms
";
- echo "class list: ".json_encode($response['classes'])."
";
+ echo 'Search took '.(($t2 - $t1) * 1000).' ms
';
+ echo 'class list: '.json_encode($response['classes']).'
';
foreach ($res as &$row) {
$fid = $row['fileid'] = (int) $row['fileid'];
$row['score'] = (float) $row['score'];
- $row['score'] = pow(2, $row['score'] * 40);
+ $row['score'] = 2 ** ($row['score'] * 40);
$p = $this->preview->getPreview($this->fs->getUserFile($fid), 1024, 1024);
$data = $p->getContent();
- //get classes for this file
+ // get classes for this file
$q = $this->connection->getQueryBuilder();
$w = $q->select('word')
->from('memories_ss_class', 'c')
->where($q->expr()->eq('c.fileid', $q->createNamedParameter($fid)))
->executeQuery()
- ->fetchAll(\PDO::FETCH_COLUMN);
+ ->fetchAll(\PDO::FETCH_COLUMN)
+ ;
- echo "Score: ". $row['score'] . "
";
- echo "Row: ".json_encode($row)."
";
- echo "Classes: ".json_encode($w)."";
+ echo 'Score: '.$row['score'].'
';
+ echo 'Row: '.json_encode($row).'
';
+ echo 'Classes: '.json_encode($w).'';
echo "";
}
@@ -227,14 +230,18 @@ private function indexUser(IUser $user): void
->from('memories', 'm')
;
- $this->tq->joinFilecache($query, $root, true, false, true);
+ $query = $this->tq->filterFilecache($query, $root, true, false, true);
// Filter by the files that are not indexed by the AI
- $query
- ->leftJoin('m', 'memories_ss_vectors', 'v', $query->expr()->eq('m.fileid', 'v.fileid'))
- ->where($query->expr()->isNull('v.fileid'))
- ->setMaxResults(16) // batch size
+ $vecSq = $this->connection->getQueryBuilder();
+ $vecSq->select($vecSq->expr()->literal(1))
+ ->from('memories_ss_vectors', 'v')
+ ->where($vecSq->expr()->eq('m.fileid', 'v.fileid'))
;
+ $query->andWhere(SQL::notExists($query, $vecSq));
+
+ // Batch size
+ $query->setMaxResults(16);
// FileIds inside this folder that need indexing
$objs = Util::transaction(fn () => $this->tq->executeQueryWithCTEs($query)->fetchAll());
@@ -256,21 +263,23 @@ private function indexSet(Folder $folder, array $objs): void
return;
}
- $count = \count($objs);
- $this->output->writeln("Indexing {$count} files");
-
+ // Get previews for all files
foreach ($objs as &$obj) {
- $obj['fileid'] = (int) $obj['fileid'];
+ $fileid = $obj['fileid'] = (int) $obj['fileid'];
$obj['mtime'] = (int) $obj['mtime'];
try {
// Get file object
$file = $folder->getById($obj['fileid']);
if (empty($file)) {
+ $this->output->writeln("File not found: {$fileid}");
+
continue;
}
$file = $file[0];
if (!$file instanceof File) {
+ $this->output->writeln("Not a file: {$fileid}");
+
continue;
}
@@ -285,13 +294,17 @@ private function indexSet(Folder $folder, array $objs): void
$mime = $preview->getMimeType();
$data = base64_encode($content);
$obj['image'] = "data:{$mime};base64,{$data}";
+
+ // Log
+ $this->output->writeln("Indexing {$file->getPath()}");
} catch (\Exception $e) {
$obj['fileid'] = 0; // mark failure
- $this->output->writeln("Failed to get preview: {$e->getMessage()}".PHP_EOL);
+ $this->output->writeln("Failed to get preview: {$e->getMessage()}");
}
}
// Filter out failed files
+ // TODO: store failure reason
$objs = array_filter($objs, static fn ($obj) => $obj['fileid'] > 0);
// Post to server
@@ -327,11 +340,11 @@ private function indexSet(Folder $folder, array $objs): void
private function ssStoreResult(array $result, int $fileid, int $mtime): void
{
// Check result
- if (768 !== \count($result['embedding'])) {
+ if (VECTOR_SIZE !== \count($result['embedding'])) {
throw new \Exception('Invalid embedding size');
}
- if (\count($result['classes']) === 0) {
+ if (0 === \count($result['classes'])) {
throw new \Exception('No classes returned.');
}
@@ -345,8 +358,8 @@ private function ssStoreResult(array $result, int $fileid, int $mtime): void
];
// Store embedding
- for ($i = 0; $i < \count($result['embedding']); ++$i) {
- $values['v'.$i] = $query->expr()->literal($result['embedding'][$i]);
+ for ($i = 0; $i < VECTOR_SIZE; ++$i) {
+ $values["v{$i}"] = $query->expr()->literal($result['embedding'][$i]);
}
$query->insert('memories_ss_vectors')
diff --git a/lib/Db/SQL.php b/lib/Db/SQL.php
index fbde6917d..9e45f64dd 100644
--- a/lib/Db/SQL.php
+++ b/lib/Db/SQL.php
@@ -68,10 +68,11 @@ public static function materialize(IQueryBuilder $query, string $alias): IQueryB
*
* @param IQueryBuilder $query The query to create the function on
* @param IQueryBuilder $subquery The subquery to use
+ * @param string $alias The alias to use for the subquery
*/
- public static function subquery(IQueryBuilder &$query, IQueryBuilder &$subquery): IQueryFunction
+ public static function subquery(IQueryBuilder &$query, IQueryBuilder &$subquery, string $alias = ''): IQueryFunction
{
- return $query->createFunction("({$subquery->getSQL()})");
+ return $query->createFunction("({$subquery->getSQL()}) {$alias}");
}
/**
diff --git a/lib/Migration/Version800000Date20240327191449.php b/lib/Migration/Version900000Date20240327191449.php
similarity index 72%
rename from lib/Migration/Version800000Date20240327191449.php
rename to lib/Migration/Version900000Date20240327191449.php
index b7c8d343e..ed00d97a2 100644
--- a/lib/Migration/Version800000Date20240327191449.php
+++ b/lib/Migration/Version900000Date20240327191449.php
@@ -28,12 +28,16 @@
use OCP\Migration\IOutput;
use OCP\Migration\SimpleMigrationStep;
-class Version800000Date20240327191449 extends SimpleMigrationStep
+class Version900000Date20240327191449 extends SimpleMigrationStep
{
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
*/
- public function preSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void {}
+ public function preSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void
+ {
+ // Patch doctrine to use float instead of double
+ \Doctrine\DBAL\Types\Type::overrideType(Types::FLOAT, RealFloatType::class);
+ }
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
@@ -58,10 +62,6 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
'notnull' => true,
'length' => 20,
]);
- $table->addColumn('lsh', Types::INTEGER, [
- 'notnull' => true,
- 'default' => 0,
- ]);
// Create embedding columns
$size = 768;
@@ -74,7 +74,6 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
$table->setPrimaryKey(['id']);
$table->addIndex(['fileid', 'mtime'], 'memories_ss_vec_fileid');
- $table->addIndex(['lsh'], 'memories_ss_vec_lsh');
}
return $schema;
@@ -83,5 +82,26 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
*/
- public function postSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void {}
+ public function postSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void
+ {
+ // Revert doctrine patch
+ \Doctrine\DBAL\Types\Type::overrideType(Types::FLOAT, \Doctrine\DBAL\Types\FloatType::class);
+ }
+}
+
+class RealFloatType extends \Doctrine\DBAL\Types\FloatType
+{
+ public function getSQLDeclaration(array $column, \Doctrine\DBAL\Platforms\AbstractPlatform $platform)
+ {
+ if (preg_match('/mysql|mariadb/i', $platform::class)) {
+ return 'FLOAT';
+ }
+
+ // https://www.postgresql.org/docs/current/datatype-numeric.html
+ if (preg_match('/postgres/i', $platform::class)) {
+ return 'REAL';
+ }
+
+ return parent::getSQLDeclaration($column, $platform);
+ }
}
diff --git a/lib/Migration/Version800000Date20240327192949.php b/lib/Migration/Version900000Date20240327192949.php
similarity index 92%
rename from lib/Migration/Version800000Date20240327192949.php
rename to lib/Migration/Version900000Date20240327192949.php
index 20cf6a105..5fba29a7d 100644
--- a/lib/Migration/Version800000Date20240327192949.php
+++ b/lib/Migration/Version900000Date20240327192949.php
@@ -28,7 +28,7 @@
use OCP\Migration\IOutput;
use OCP\Migration\SimpleMigrationStep;
-class Version800000Date20240327192949 extends SimpleMigrationStep
+class Version900000Date20240327192949 extends SimpleMigrationStep
{
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
@@ -61,6 +61,10 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
'notnull' => true,
'default' => 0,
]);
+ $table->addColumn('word', Types::STRING, [
+ 'notnull' => false,
+ 'length' => 64,
+ ]);
$table->setPrimaryKey(['id']);
$table->addIndex(['fileid'], 'memories_ss_cls_fileid');