diff --git a/lib/Command/AI.php b/lib/Command/AI.php index e02e7120a..e63198d84 100644 --- a/lib/Command/AI.php +++ b/lib/Command/AI.php @@ -23,8 +23,8 @@ namespace OCA\Memories\Command; -use OC\DB\QueryBuilder\QueryBuilder; use OCA\Memories\Db\FsManager; +use OCA\Memories\Db\SQL; use OCA\Memories\Db\TimelineQuery; use OCA\Memories\Db\TimelineRoot; use OCA\Memories\Util; @@ -43,6 +43,7 @@ const API_IMAGES = '/images'; const API_TEXT = '/text'; +const VECTOR_SIZE = 768; class AIOpts { @@ -85,24 +86,21 @@ public function search(string $prompt): array $query = $this->connection->getQueryBuilder(); $classlist = array_map(static fn (array $class): int => $class['index'], $response['classes']); - // $classlist = array_slice($classlist, 0, 1); + // $classlist = \array_slice($classlist, 0, 8); $classQuery = $this->connection->getQueryBuilder(); $classQuery->select('c.word') ->from('memories_ss_class', 'c') - ->where($query->expr()->andX( - $query->expr()->eq('c.fileid', 'v.fileid'), - $query->expr()->orX( - ...array_map(static fn ($idx) => - $query->expr()->eq('c.class', $query->expr()->literal($idx)), - $classlist) - ), - )); + ->where($classQuery->expr()->andX( + $classQuery->expr()->eq('c.fileid', 'v.fileid'), + $classQuery->expr()->in('c.class', array_map(static fn ($idx) => $classQuery->expr()->literal($idx), $classlist)), + )) + ; $subquery = $this->connection->getQueryBuilder(); $subquery->select('v.fileid') ->from('memories_ss_vectors', 'v') - ->where($subquery->createFunction("EXISTS ({$classQuery->getSql()})")) + ->where(SQL::exists($query, $classQuery)) ->groupBy('v.fileid') ; @@ -110,66 +108,71 @@ public function search(string $prompt): array $components = []; foreach ($response['embedding'] as $i => $value) { $value = number_format($value, 6); - $components[] = "v.v{$i}*{$value}"; + $components[] = "(v.v{$i}*({$value}))"; } - // Divide the operators into chunks of 96 each - $sums = array_chunk($components, 96); + // Divide the operators into chunks of 48 each + $sums = array_chunk($components, 48); // Add the sum of each chunk for ($i = 0; $i < \count($sums); ++$i) { - $sum = implode('+', $sums[$i]); - $subquery->addSelect($subquery->createFunction("({$sum}) as score{$i}")); + $sum = $subquery->createFunction(implode('+', $sums[$i])); + $subquery->selectAlias($sum, "score{$i}"); } // Create outer query $query->select('sq.fileid') - ->from($query->createFunction("({$subquery->getSQL()}) sq")) + ->from(SQL::subquery($query, $subquery, 'sq')) ; // Add all score sums together - $sum = implode('+', array_map(static fn ($_, $i) => "score{$i}", $sums, array_keys($sums))); - $query->addSelect($query->createFunction("({$sum}) as score")); + $finalSum = implode('+', array_map(static fn ($_, $i) => "score{$i}", $sums, array_keys($sums))); + $finalSum = $query->createFunction("({$finalSum})"); + $query->selectAlias($finalSum, 'score'); // Filter for scores less than 1 - // $query->andWhere($query->createFunction("(({$sum}) > 0.04)")); + $query = SQL::materialize($query, 'fsq'); + $query->andWhere($query->expr()->gt('fsq.score', $query->expr()->literal(0.04))); - $query->orderBy('score', 'DESC'); + $query->orderBy('fsq.score', 'DESC'); // $query->setMaxResults(8); // batch size - header('Content-Type: text/html'); + + // SQL::debugQuery($query); $t1 = microtime(true); + $res = $query->executeQuery()->fetchAll(); // print length and discard after 10 - echo "

Results: ".\count($res)."

"; - $res = array_slice($res, 0, 10); + echo '

Results: '.\count($res).'

'; + $res = \array_slice($res, 0, 10); $t2 = microtime(true); - echo "

Search took ".(($t2 - $t1)*1000)." ms

"; - echo "class list: ".json_encode($response['classes'])."
"; + echo '

Search took '.(($t2 - $t1) * 1000).' ms

'; + echo 'class list: '.json_encode($response['classes']).'
'; foreach ($res as &$row) { $fid = $row['fileid'] = (int) $row['fileid']; $row['score'] = (float) $row['score']; - $row['score'] = pow(2, $row['score'] * 40); + $row['score'] = 2 ** ($row['score'] * 40); $p = $this->preview->getPreview($this->fs->getUserFile($fid), 1024, 1024); $data = $p->getContent(); - //get classes for this file + // get classes for this file $q = $this->connection->getQueryBuilder(); $w = $q->select('word') ->from('memories_ss_class', 'c') ->where($q->expr()->eq('c.fileid', $q->createNamedParameter($fid))) ->executeQuery() - ->fetchAll(\PDO::FETCH_COLUMN); + ->fetchAll(\PDO::FETCH_COLUMN) + ; - echo "

Score: ". $row['score'] . "

"; - echo "Row: ".json_encode($row)."
"; - echo "Classes: ".json_encode($w)."
"; + echo '

Score: '.$row['score'].'

'; + echo 'Row: '.json_encode($row).'
'; + echo 'Classes: '.json_encode($w).'
'; echo ""; } @@ -227,14 +230,18 @@ private function indexUser(IUser $user): void ->from('memories', 'm') ; - $this->tq->joinFilecache($query, $root, true, false, true); + $query = $this->tq->filterFilecache($query, $root, true, false, true); // Filter by the files that are not indexed by the AI - $query - ->leftJoin('m', 'memories_ss_vectors', 'v', $query->expr()->eq('m.fileid', 'v.fileid')) - ->where($query->expr()->isNull('v.fileid')) - ->setMaxResults(16) // batch size + $vecSq = $this->connection->getQueryBuilder(); + $vecSq->select($vecSq->expr()->literal(1)) + ->from('memories_ss_vectors', 'v') + ->where($vecSq->expr()->eq('m.fileid', 'v.fileid')) ; + $query->andWhere(SQL::notExists($query, $vecSq)); + + // Batch size + $query->setMaxResults(16); // FileIds inside this folder that need indexing $objs = Util::transaction(fn () => $this->tq->executeQueryWithCTEs($query)->fetchAll()); @@ -256,21 +263,23 @@ private function indexSet(Folder $folder, array $objs): void return; } - $count = \count($objs); - $this->output->writeln("Indexing {$count} files"); - + // Get previews for all files foreach ($objs as &$obj) { - $obj['fileid'] = (int) $obj['fileid']; + $fileid = $obj['fileid'] = (int) $obj['fileid']; $obj['mtime'] = (int) $obj['mtime']; try { // Get file object $file = $folder->getById($obj['fileid']); if (empty($file)) { + $this->output->writeln("File not found: {$fileid}"); + continue; } $file = $file[0]; if (!$file instanceof File) { + $this->output->writeln("Not a file: {$fileid}"); + continue; } @@ -285,13 +294,17 @@ private function indexSet(Folder $folder, array $objs): void $mime = $preview->getMimeType(); $data = base64_encode($content); $obj['image'] = "data:{$mime};base64,{$data}"; + + // Log + $this->output->writeln("Indexing {$file->getPath()}"); } catch (\Exception $e) { $obj['fileid'] = 0; // mark failure - $this->output->writeln("Failed to get preview: {$e->getMessage()}".PHP_EOL); + $this->output->writeln("Failed to get preview: {$e->getMessage()}"); } } // Filter out failed files + // TODO: store failure reason $objs = array_filter($objs, static fn ($obj) => $obj['fileid'] > 0); // Post to server @@ -327,11 +340,11 @@ private function indexSet(Folder $folder, array $objs): void private function ssStoreResult(array $result, int $fileid, int $mtime): void { // Check result - if (768 !== \count($result['embedding'])) { + if (VECTOR_SIZE !== \count($result['embedding'])) { throw new \Exception('Invalid embedding size'); } - if (\count($result['classes']) === 0) { + if (0 === \count($result['classes'])) { throw new \Exception('No classes returned.'); } @@ -345,8 +358,8 @@ private function ssStoreResult(array $result, int $fileid, int $mtime): void ]; // Store embedding - for ($i = 0; $i < \count($result['embedding']); ++$i) { - $values['v'.$i] = $query->expr()->literal($result['embedding'][$i]); + for ($i = 0; $i < VECTOR_SIZE; ++$i) { + $values["v{$i}"] = $query->expr()->literal($result['embedding'][$i]); } $query->insert('memories_ss_vectors') diff --git a/lib/Db/SQL.php b/lib/Db/SQL.php index fbde6917d..9e45f64dd 100644 --- a/lib/Db/SQL.php +++ b/lib/Db/SQL.php @@ -68,10 +68,11 @@ public static function materialize(IQueryBuilder $query, string $alias): IQueryB * * @param IQueryBuilder $query The query to create the function on * @param IQueryBuilder $subquery The subquery to use + * @param string $alias The alias to use for the subquery */ - public static function subquery(IQueryBuilder &$query, IQueryBuilder &$subquery): IQueryFunction + public static function subquery(IQueryBuilder &$query, IQueryBuilder &$subquery, string $alias = ''): IQueryFunction { - return $query->createFunction("({$subquery->getSQL()})"); + return $query->createFunction("({$subquery->getSQL()}) {$alias}"); } /** diff --git a/lib/Migration/Version800000Date20240327191449.php b/lib/Migration/Version900000Date20240327191449.php similarity index 72% rename from lib/Migration/Version800000Date20240327191449.php rename to lib/Migration/Version900000Date20240327191449.php index b7c8d343e..ed00d97a2 100644 --- a/lib/Migration/Version800000Date20240327191449.php +++ b/lib/Migration/Version900000Date20240327191449.php @@ -28,12 +28,16 @@ use OCP\Migration\IOutput; use OCP\Migration\SimpleMigrationStep; -class Version800000Date20240327191449 extends SimpleMigrationStep +class Version900000Date20240327191449 extends SimpleMigrationStep { /** * @param \Closure(): ISchemaWrapper $schemaClosure */ - public function preSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void {} + public function preSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void + { + // Patch doctrine to use float instead of double + \Doctrine\DBAL\Types\Type::overrideType(Types::FLOAT, RealFloatType::class); + } /** * @param \Closure(): ISchemaWrapper $schemaClosure @@ -58,10 +62,6 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op 'notnull' => true, 'length' => 20, ]); - $table->addColumn('lsh', Types::INTEGER, [ - 'notnull' => true, - 'default' => 0, - ]); // Create embedding columns $size = 768; @@ -74,7 +74,6 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op $table->setPrimaryKey(['id']); $table->addIndex(['fileid', 'mtime'], 'memories_ss_vec_fileid'); - $table->addIndex(['lsh'], 'memories_ss_vec_lsh'); } return $schema; @@ -83,5 +82,26 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op /** * @param \Closure(): ISchemaWrapper $schemaClosure */ - public function postSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void {} + public function postSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void + { + // Revert doctrine patch + \Doctrine\DBAL\Types\Type::overrideType(Types::FLOAT, \Doctrine\DBAL\Types\FloatType::class); + } +} + +class RealFloatType extends \Doctrine\DBAL\Types\FloatType +{ + public function getSQLDeclaration(array $column, \Doctrine\DBAL\Platforms\AbstractPlatform $platform) + { + if (preg_match('/mysql|mariadb/i', $platform::class)) { + return 'FLOAT'; + } + + // https://www.postgresql.org/docs/current/datatype-numeric.html + if (preg_match('/postgres/i', $platform::class)) { + return 'REAL'; + } + + return parent::getSQLDeclaration($column, $platform); + } } diff --git a/lib/Migration/Version800000Date20240327192949.php b/lib/Migration/Version900000Date20240327192949.php similarity index 92% rename from lib/Migration/Version800000Date20240327192949.php rename to lib/Migration/Version900000Date20240327192949.php index 20cf6a105..5fba29a7d 100644 --- a/lib/Migration/Version800000Date20240327192949.php +++ b/lib/Migration/Version900000Date20240327192949.php @@ -28,7 +28,7 @@ use OCP\Migration\IOutput; use OCP\Migration\SimpleMigrationStep; -class Version800000Date20240327192949 extends SimpleMigrationStep +class Version900000Date20240327192949 extends SimpleMigrationStep { /** * @param \Closure(): ISchemaWrapper $schemaClosure @@ -61,6 +61,10 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op 'notnull' => true, 'default' => 0, ]); + $table->addColumn('word', Types::STRING, [ + 'notnull' => false, + 'length' => 64, + ]); $table->setPrimaryKey(['id']); $table->addIndex(['fileid'], 'memories_ss_cls_fileid');