Skip to content

Commit d1b07e7

Browse files
Avital-Finechayim
andauthored
Support Vector Similarity (#151)
* Support Vector Similarity Co-authored-by: Chayim I. Kirshen <[email protected]>
1 parent 5c6afe8 commit d1b07e7

File tree

6 files changed

+156
-24
lines changed

6 files changed

+156
-24
lines changed

.circleci/config.yml

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
build: # test with redisearch:latest
5454
docker:
5555
- image: circleci/golang:1.16
56-
- image: redislabs/redisearch:latest
56+
- image: redislabs/redisearch:edge
5757

5858
working_directory: /go/src/github.com/RediSearch/redisearch-go
5959
steps:
@@ -65,34 +65,12 @@ jobs:
6565
- run: make coverage
6666
- run: bash <(curl -s https://raw.githubusercontent.com/codecov/codecov-bash/master/codecov) -t ${CODECOV_TOKEN}
6767

68-
build-v16:
69-
docker:
70-
- image: circleci/golang:1.16
71-
- image: redislabs/redisearch:1.6.15
72-
73-
working_directory: /go/src/github.com/RediSearch/redisearch-go
74-
steps:
75-
- checkout
76-
- run: make test
77-
78-
build-edge: # test nightly with redisearch:edge
79-
docker:
80-
- image: circleci/golang:1.16
81-
- image: redislabs/redisearch:edge
82-
83-
working_directory: /go/src/github.com/RediSearch/redisearch-go
84-
steps:
85-
- checkout
86-
- run: make test
87-
8868
workflows:
8969
version: 2
9070
commit:
9171
jobs:
9272
- build-tls
9373
- build
94-
- build-edge
95-
- build-v16
9674
nightly:
9775
triggers:
9876
- schedule:
@@ -102,5 +80,5 @@ workflows:
10280
only:
10381
- master
10482
jobs:
105-
- build-edge
83+
- build
10684
- build-tls

redisearch/query.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ type Query struct {
9393
SortBy *SortingKey
9494
HighlightOpts *HighlightOptions
9595
SummarizeOpts *SummaryOptions
96+
Params map[string]interface{}
97+
Dialect int
9698
}
9799

98100
// Paging represents the offset paging of a search result
@@ -227,6 +229,18 @@ func (q Query) serialize() redis.Args {
227229
}
228230
}
229231
}
232+
233+
if q.Params != nil {
234+
args = args.Add("PARAMS", len(q.Params)*2)
235+
for name, value := range q.Params {
236+
args = args.Add(name, value)
237+
}
238+
}
239+
240+
if q.Dialect != 0 {
241+
args = args.Add("DIALECT", q.Dialect)
242+
}
243+
230244
return args
231245
}
232246

@@ -370,6 +384,29 @@ func (q *Query) SummarizeOptions(opts SummaryOptions) *Query {
370384
return q
371385
}
372386

387+
// SetParams sets parameters that can be referenced in the query string by a $ , followed by the parameter name,
388+
// e.g., $user , and each such reference in the search query to a parameter name is substituted
389+
// by the corresponding parameter value.
390+
func (q *Query) SetParams(params map[string]interface{}) *Query {
391+
q.Params = params
392+
return q
393+
}
394+
395+
// AddParam adds a new param to the parameters list
396+
func (q *Query) AddParam(name string, value interface{}) *Query {
397+
if q.Params == nil {
398+
q.Params = make(map[string]interface{})
399+
}
400+
q.Params[name] = value
401+
return q
402+
}
403+
404+
// SetDialect can have one of 2 options: 1 or 2
405+
func (q *Query) SetDialect(dialect int) *Query {
406+
q.Dialect = dialect
407+
return q
408+
}
409+
373410
// IndexOptions indexes multiple documents on the index, with optional Options passed to options
374411
func (i *Client) IndexOptions(opts IndexingOptions, docs ...Document) error {
375412

redisearch/query_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ func TestQuery_serialize(t *testing.T) {
7979
SortBy *SortingKey
8080
HighlightOpts *HighlightOptions
8181
SummarizeOpts *SummaryOptions
82+
Params map[string]interface{}
83+
Dialect int
8284
}
8385
tests := []struct {
8486
name string
@@ -111,6 +113,8 @@ func TestQuery_serialize(t *testing.T) {
111113
NumFragments: 3,
112114
Separator: "...",
113115
}}, redis.Args{raw, "LIMIT", 0, 0, "SUMMARIZE", "FIELDS", 1, "test_field", "LEN", 20, "FRAGS", 3, "SEPARATOR", "..."}},
116+
{"Params", fields{Raw: raw, Params: map[string]interface{}{"min": 1}}, redis.Args{raw, "LIMIT", 0, 0, "PARAMS", 2, "min", 1}},
117+
{"Dialect", fields{Raw: raw, Dialect: 2}, redis.Args{raw, "LIMIT", 0, 0, "DIALECT", 2}},
114118
}
115119
for _, tt := range tests {
116120
t.Run(tt.name, func(t *testing.T) {
@@ -126,6 +130,8 @@ func TestQuery_serialize(t *testing.T) {
126130
SortBy: tt.fields.SortBy,
127131
HighlightOpts: tt.fields.HighlightOpts,
128132
SummarizeOpts: tt.fields.SummarizeOpts,
133+
Params: tt.fields.Params,
134+
Dialect: tt.fields.Dialect,
129135
}
130136
if g := q.serialize(); !reflect.DeepEqual(g, tt.want) {
131137
t.Errorf("serialize() = %v, want %v", g, tt.want)

redisearch/redisearch_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,67 @@ func TestReturnFields(t *testing.T) {
510510
assert.Equal(t, "Jon", docs[0].Properties["name"])
511511
assert.Equal(t, "25", docs[0].Properties["years"])
512512
}
513+
514+
func TestParams(t *testing.T) {
515+
c := createClient("TestParams")
516+
version, _ := c.getRediSearchVersion()
517+
if version < 20430 {
518+
// VectorSimilarity is available for RediSearch 2.2+
519+
return
520+
}
521+
522+
// Create a schema
523+
sc := NewSchema(DefaultOptions).AddField(NewNumericField("numval"))
524+
c.Drop()
525+
assert.Nil(t, c.CreateIndex(sc))
526+
// Create data
527+
_, err := c.pool.Get().Do("HSET", "1", "numval", "1")
528+
assert.Nil(t, err)
529+
_, err = c.pool.Get().Do("HSET", "2", "numval", "2")
530+
assert.Nil(t, err)
531+
_, err = c.pool.Get().Do("HSET", "3", "numval", "3")
532+
assert.Nil(t, err)
533+
// Searching with parameters
534+
_, total, err := c.Search(NewQuery("@numval:[$min $max]").
535+
SetParams(map[string]interface{}{"min": "1", "max": "2"}).
536+
SetDialect(2))
537+
assert.Nil(t, err)
538+
assert.Equal(t, 2, total)
539+
}
540+
541+
func TestVectorField(t *testing.T) {
542+
c := createClient("TestVectorField")
543+
version, _ := c.getRediSearchVersion()
544+
if version < 20430 {
545+
// VectorSimilarity is available for RediSearch 2.2+
546+
return
547+
}
548+
549+
// Create a schema
550+
sc := NewSchema(DefaultOptions).AddField(
551+
NewVectorFieldOptions("v", VectorFieldOptions{Algorithm: Flat, Attributes: map[string]interface{}{
552+
"TYPE": "FLOAT32",
553+
"DIM": 2,
554+
"DISTANCE_METRIC": "L2",
555+
}}),
556+
)
557+
c.Drop()
558+
assert.Nil(t, c.CreateIndex(sc))
559+
// Create data
560+
_, err := c.pool.Get().Do("HSET", "a", "v", "aaaaaaaa")
561+
assert.Nil(t, err)
562+
_, err = c.pool.Get().Do("HSET", "b", "v", "aaaabaaa")
563+
assert.Nil(t, err)
564+
_, err = c.pool.Get().Do("HSET", "c", "v", "aaaaabaa")
565+
assert.Nil(t, err)
566+
// Searching with parameters
567+
docs, total, err := c.Search(NewQuery("*=>[KNN 2 @v $vec]").
568+
AddParam("vec", "aaaaaaaa").
569+
SetSortBy("__v_score", true).
570+
AddReturnFields("__v_score").
571+
SetDialect(2))
572+
assert.Nil(t, err)
573+
assert.Equal(t, 2, total)
574+
assert.Equal(t, "a", docs[0].Id)
575+
assert.Equal(t, "0", docs[0].Properties["__v_score"])
576+
}

redisearch/schema.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ const (
135135

136136
// TagField is a field used for compact indexing of comma separated values
137137
TagField
138+
139+
//VectorField allows vector similarity queries against the value in this attribute.
140+
VectorField
138141
)
139142

140143
// Phonetic Matchers
@@ -185,6 +188,20 @@ type GeoFieldOptions struct {
185188
As string
186189
}
187190

191+
type algorithm string
192+
193+
// Supported algorithms for Vector field
194+
const (
195+
Flat algorithm = "FLAT"
196+
HNSW algorithm = "HNSW"
197+
)
198+
199+
// VectorFieldOptions Options for vector fields
200+
type VectorFieldOptions struct {
201+
Algorithm algorithm
202+
Attributes map[string]interface{}
203+
}
204+
188205
// NewTextField creates a new text field with the given weight
189206
func NewTextField(name string) Field {
190207
return Field{
@@ -268,6 +285,15 @@ func NewGeoFieldOptions(name string, options GeoFieldOptions) Field {
268285
return f
269286
}
270287

288+
// NewVectorFieldOptions creates a new geo field with the given name and additional options
289+
func NewVectorFieldOptions(name string, options VectorFieldOptions) Field {
290+
return Field{
291+
Name: name,
292+
Type: VectorField,
293+
Options: options,
294+
}
295+
}
296+
271297
// Schema represents an index schema Schema, or how the index would
272298
// treat documents sent to it.
273299
type Schema struct {
@@ -417,6 +443,26 @@ func serializeField(f Field, args redis.Args) (argsOut redis.Args, err error) {
417443
argsOut = append(argsOut, "NOINDEX")
418444
}
419445
}
446+
case VectorField:
447+
argsOut = append(argsOut, f.Name, "VECTOR")
448+
if f.Options != nil {
449+
opts, ok := f.Options.(VectorFieldOptions)
450+
if !ok {
451+
err = fmt.Errorf("Error on VectorField serialization")
452+
return
453+
}
454+
if opts.Algorithm != "" {
455+
argsOut = append(argsOut, opts.Algorithm)
456+
}
457+
if opts.Attributes != nil {
458+
var flat []interface{}
459+
for attrName, attrValue := range opts.Attributes {
460+
flat = append(flat, attrName, attrValue)
461+
}
462+
argsOut = append(argsOut, len(flat))
463+
argsOut = append(argsOut, flat...)
464+
}
465+
}
420466
default:
421467
err = fmt.Errorf("Unrecognized field type %v serialization", f.Type)
422468
return

redisearch/schema_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ func TestSerializeSchema(t *testing.T) {
7474
{"default-and-tag", args{NewSchema(DefaultOptions).AddField(NewTagField("tag-field")), redis.Args{}}, redis.Args{"SCHEMA", "tag-field", "TAG", "SEPARATOR", ","}, false},
7575
{"default-and-tag-with-options", args{NewSchema(DefaultOptions).AddField(NewTagFieldOptions("tag-field", TagFieldOptions{Sortable: true, NoIndex: false, Separator: byte(','), As: "field"})), redis.Args{}}, redis.Args{"SCHEMA", "tag-field", "AS", "field", "TAG", "SEPARATOR", ",", "SORTABLE"}, false},
7676
{"default-geo-with-options", args{NewSchema(DefaultOptions).AddField(NewGeoFieldOptions("location", GeoFieldOptions{As: "loc"})), redis.Args{}}, redis.Args{"SCHEMA", "location", "AS", "loc", "GEO"}, false},
77+
{"default-vector", args{NewSchema(DefaultOptions).AddField(NewVectorFieldOptions("vec", VectorFieldOptions{Algorithm: Flat, Attributes: map[string]interface{}{"DIM": 128}})), redis.Args{}}, redis.Args{"SCHEMA", "vec", "VECTOR", Flat, 2, "DIM", 128}, false},
7778
{"error-unsupported", args{NewSchema(DefaultOptions).AddField(Field{Type: 10}), redis.Args{}}, nil, true},
7879
}
7980
for _, tt := range tests {

0 commit comments

Comments
 (0)