Skip to content

Commit b4e30f5

Browse files
committed
[Store][Postgres] allow store initialization with utilized distance
1 parent fb19713 commit b4e30f5

File tree

3 files changed

+190
-13
lines changed

3 files changed

+190
-13
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Store\Bridge\Postgres;
13+
14+
use OskarStark\Enum\Trait\Comparable;
15+
16+
/**
 * Distance metric a Postgres store uses to compare a stored vector with the
 * query vector.
 *
 * Every case is backed by a string identifier and resolves to the SQL operator
 * that is written between the vector column and the bound embedding.
 * NOTE(review): the operators look like the pgvector distance operators - confirm.
 *
 * @author Denis Zunke <[email protected]>
 */
enum Distance: string
{
    use Comparable;

    case Cosine = 'cosine';
    case InnerProduct = 'inner_product';
    case L1 = 'l1';
    case L2 = 'l2';

    /**
     * Resolves the SQL comparison operator that belongs to this distance.
     *
     * @return string one of '<=>', '<#>', '<+>' or '<->'
     */
    public function getComparisonSign(): string
    {
        // match() is exhaustive over the enum: a newly added case without an
        // operator fails loudly instead of silently falling back.
        $operator = match ($this) {
            self::L2 => '<->',
            self::L1 => '<+>',
            self::InnerProduct => '<#>',
            self::Cosine => '<=>',
        };

        return $operator;
    }
}

src/store/src/Bridge/Postgres/Store.php

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,32 @@ public function __construct(
3434
private \PDO $connection,
3535
private string $tableName,
3636
private string $vectorFieldName = 'embedding',
37+
private Distance $distance = Distance::L2,
3738
) {
3839
}
3940

40-
public static function fromPdo(\PDO $connection, string $tableName, string $vectorFieldName = 'embedding'): self
41-
{
42-
return new self($connection, $tableName, $vectorFieldName);
41+
/**
 * Creates a store on top of an existing PDO connection.
 *
 * @param \PDO     $connection      connection the store prepares its statements on
 * @param string   $tableName       table name interpolated into the generated SQL
 * @param string   $vectorFieldName column holding the vectors, defaults to "embedding"
 * @param Distance $distance        distance whose operator is used for scoring, defaults to L2
 */
public static function fromPdo(
    \PDO $connection,
    string $tableName,
    string $vectorFieldName = 'embedding',
    Distance $distance = Distance::L2,
): self {
    // Named arguments keep the mapping onto the promoted constructor
    // properties explicit now that there are four of them.
    return new self(
        connection: $connection,
        tableName: $tableName,
        vectorFieldName: $vectorFieldName,
        distance: $distance,
    );
}
4449

45-
public static function fromDbal(Connection $connection, string $tableName, string $vectorFieldName = 'embedding'): self
46-
{
50+
/**
 * Creates a store from a Doctrine DBAL connection by unwrapping its native
 * PDO handle and delegating to fromPdo().
 *
 * @param Connection $connection      DBAL connection; its native connection must be a \PDO instance
 * @param string     $tableName       table name interpolated into the generated SQL
 * @param string     $vectorFieldName column holding the vectors, defaults to "embedding"
 * @param Distance   $distance        distance whose operator is used for scoring, defaults to L2
 *
 * @throws InvalidArgumentException when the native connection is not a \PDO instance
 */
public static function fromDbal(
    Connection $connection,
    string $tableName,
    string $vectorFieldName = 'embedding',
    Distance $distance = Distance::L2,
): self {
    $nativeConnection = $connection->getNativeConnection();

    if ($nativeConnection instanceof \PDO) {
        return self::fromPdo($nativeConnection, $tableName, $vectorFieldName, $distance);
    }

    // Anything else cannot be handed to the PDO-based store.
    throw new InvalidArgumentException('Only DBAL connections using PDO driver are supported.');
}
5564

5665
public function add(VectorDocument ...$documents): void
@@ -84,16 +93,18 @@ public function add(VectorDocument ...$documents): void
8493
*/
8594
public function query(Vector $vector, array $options = [], ?float $minScore = null): array
8695
{
87-
$sql = \sprintf(
88-
'SELECT id, %s AS embedding, metadata, (%s <-> :embedding) AS score
89-
FROM %s
90-
%s
91-
ORDER BY score ASC
92-
LIMIT %d',
96+
$sql = \sprintf(<<<SQL
97+
SELECT id, %s AS embedding, metadata, (%s %s :embedding) AS score
98+
FROM %s
99+
%s
100+
ORDER BY score ASC
101+
LIMIT %d
102+
SQL,
93103
$this->vectorFieldName,
94104
$this->vectorFieldName,
105+
$this->distance->getComparisonSign(),
95106
$this->tableName,
96-
null !== $minScore ? "WHERE ({$this->vectorFieldName} <-> :embedding) >= :minScore" : '',
107+
null !== $minScore ? "WHERE ({$this->vectorFieldName} {$this->distance->getComparisonSign()} :embedding) >= :minScore" : '',
97108
$options['limit'] ?? 5,
98109
);
99110
$statement = $this->connection->prepare($sql);
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Store\Tests\Bridge\Postgres;
13+
14+
use PHPUnit\Framework\Attributes\CoversClass;
15+
use PHPUnit\Framework\Attributes\Test;
16+
use PHPUnit\Framework\TestCase;
17+
use Symfony\AI\Platform\Vector\Vector;
18+
use Symfony\AI\Store\Bridge\Postgres\Distance;
19+
use Symfony\AI\Store\Bridge\Postgres\Store;
20+
use Symfony\AI\Store\Document\VectorDocument;
21+
use Symfony\Component\Uid\Uuid;
22+
23+
/**
 * Checks the SQL that the Postgres store prepares for similarity queries when
 * it is configured with the cosine distance, and that the fetched row is
 * turned into a vector document.
 */
#[CoversClass(Store::class)]
final class StoreTest extends TestCase
{
    /**
     * With a minimum score the statement must carry a WHERE clause built with
     * the cosine operator, and both the embedding and the threshold are bound.
     */
    #[Test]
    public function queryWithMinScore(): void
    {
        $vectorData = [0.1, 0.2, 0.3];
        $minScore = 0.8;

        $statement = $this->createStatementReturningOneRow($vectorData, [
            'embedding' => json_encode($vectorData),
            'minScore' => $minScore,
        ]);

        $pdo = $this->createMock(\PDO::class);
        $pdo->expects($this->once())
            ->method('prepare')
            ->with(<<<'SQL'
                SELECT id, embedding_index AS embedding, metadata, (embedding_index <=> :embedding) AS score
                FROM embeddings_table
                WHERE (embedding_index <=> :embedding) >= :minScore
                ORDER BY score ASC
                LIMIT 5
                SQL)
            ->willReturn($statement);

        $store = new Store($pdo, 'embeddings_table', 'embedding_index', Distance::Cosine);

        $this->assertSingleTestDocument($store->query(new Vector($vectorData), [], $minScore));
    }

    /**
     * Without a minimum score the WHERE placeholder collapses to an empty line
     * and only the embedding is bound.
     */
    #[Test]
    public function queryWithoutMinScore(): void
    {
        $vectorData = [0.1, 0.2, 0.3];

        $statement = $this->createStatementReturningOneRow($vectorData, [
            'embedding' => json_encode($vectorData),
        ]);

        $pdo = $this->createMock(\PDO::class);
        $pdo->expects($this->once())
            ->method('prepare')
            ->with(<<<'SQL'
                SELECT id, embedding_index AS embedding, metadata, (embedding_index <=> :embedding) AS score
                FROM embeddings_table

                ORDER BY score ASC
                LIMIT 5
                SQL)
            ->willReturn($statement);

        $store = new Store($pdo, 'embeddings_table', 'embedding_index', Distance::Cosine);

        $this->assertSingleTestDocument($store->query(new Vector($vectorData)));
    }

    /**
     * Builds a statement mock that must be executed exactly once with the
     * given parameters and then yields a single row for the given vector.
     *
     * @param list<float>          $vectorData         embedding stored in the fake row
     * @param array<string, mixed> $expectedParameters parameters the store has to pass to execute()
     */
    private function createStatementReturningOneRow(array $vectorData, array $expectedParameters): \PDOStatement
    {
        $statement = $this->createMock(\PDOStatement::class);

        $statement->expects($this->once())
            ->method('execute')
            ->with($expectedParameters);

        $statement->expects($this->once())
            ->method('fetchAll')
            ->with(\PDO::FETCH_ASSOC)
            ->willReturn([
                [
                    'id' => Uuid::v4()->toBinary(),
                    'embedding' => json_encode($vectorData),
                    'metadata' => json_encode(['title' => 'Test Document']),
                    'score' => 0.85,
                ],
            ]);

        return $statement;
    }

    /**
     * Asserts that the query produced exactly the one document described by
     * the fake row from createStatementReturningOneRow().
     *
     * @param array<mixed> $results value returned by Store::query()
     */
    private function assertSingleTestDocument(array $results): void
    {
        self::assertCount(1, $results);
        self::assertInstanceOf(VectorDocument::class, $results[0]);
        self::assertSame(0.85, $results[0]->score);
        self::assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
    }
}

0 commit comments

Comments
 (0)