diff --git a/README.md b/README.md index b372944..6cb6cfe 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,12 @@ AI.MODELSET | [ModelSet](https://godoc.org/github.com/RedisAI/redisai-go/redisai AI.MODELGET | [ModelGet](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ModelGet) and [ModelGetToModel](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ModelGetToModel) AI.MODELDEL | [ModelDel](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ModelDel) AI.MODELRUN | [ModelRun](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ModelRun) -AI._MODELSCAN | +AI._MODELSCAN | [ModelScan](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ModelScan) AI.SCRIPTSET | [ScriptSet](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptSet) AI.SCRIPTGET | [ScriptGet](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptGet) -AI.SCRIPTDEL | [ScriptDel](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptRun) -AI.SCRIPTRUN | [ScriptRun](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptDel) -AI._SCRIPTSCAN | +AI.SCRIPTDEL | [ScriptDel](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptDel) +AI.SCRIPTRUN | [ScriptRun](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptRun) +AI._SCRIPTSCAN | [ScriptScan](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.ScriptScan) AI.DAGRUN | [DagRun](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.DagRun) AI.DAGRUN_RO | [DagRunRO](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.DagRunRO) AI.INFO | [Info](https://godoc.org/github.com/RedisAI/redisai-go/redisai#Client.Info) diff --git a/redisai/commands.go b/redisai/commands.go index 419c115..a69a829 100644 --- a/redisai/commands.go +++ b/redisai/commands.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/gomodule/redigo/redis" + "log" "strconv" ) @@ -237,3 +238,34 @@ func AddDagRunArgs(loadKeys []string, persistKeys []string, commandArgs redis.Ar } return args } + +// Returns all the models in the database. +func (c *Client) ModelScan() ([][]string, error) { + reply, err := c.DoOrSend("AI._MODELSCAN", redis.Args{}, nil) + return ParseScanResult(reply, err) +} + +// Returns all the scripts in the database. +func (c *Client) ScriptScan() ([][]string, error) { + reply, err := c.DoOrSend("AI._SCRIPTSCAN", redis.Args{}, nil) + return ParseScanResult(reply, err) +} + +// ParseResult for AI._SCRIPTSCAN AI._MODELSCAN +func ParseScanResult(reply interface{}, err error) ([][]string, error) { + values, err := redis.Values(reply, err) + if err != nil { + return nil, err + } + + res := make([][]string, len(values), 2) + for i := 0; i < len(values); i++ { + if d, e := redis.Strings(values[i], nil); e == nil { + res[i] = d + } else { + log.Print("Error parsing ParseScanResult Reply: ", e) + res[i] = nil + } + } + return res, nil +} diff --git a/redisai/commands_test.go b/redisai/commands_test.go index ba6a4c4..d562ee0 100644 --- a/redisai/commands_test.go +++ b/redisai/commands_test.go @@ -978,3 +978,48 @@ func TestCommand_DagRunRO(t *testing.T) { }) } } + +func TestCommand_ModelScan(t *testing.T) { + c := createTestClient() + _, err := c.DoOrSend("FLUSHALL", redis.Args{}, nil) + // empty test + modelValues, err := c.ModelScan() + assert.Nil(t, err) + assert.Equal(t, 0, len(modelValues)) + + keyModel1 := "testModelScan" + model := implementations.NewModel("TF", "CPU") + model.SetInputs([]string{"transaction", "reference"}) + model.SetOutputs([]string{"output"}) + err = model.SetBlobFromFile("./../tests/test_data/creditcardfraud.pb") + assert.Nil(t, err) + err = c.ModelSetFromModel(keyModel1, model) + assert.Nil(t, err) + modelValues, err = c.ModelScan() + assert.Nil(t, err) + assert.Equal(t, 1, len(modelValues)) + assert.Equal(t, keyModel1, modelValues[0][0]) + assert.Empty(t, modelValues[0][1]) +} + +func TestCommand_ScriptScan(t *testing.T) { + c := createTestClient() + _, err := c.DoOrSend("FLUSHALL", redis.Args{}, nil) + // empty test + scriptValues, err := c.ScriptScan() + assert.Nil(t, err) + assert.Equal(t, 0, len(scriptValues)) + + keyScript1 := "test:ScriptScan:1" + scriptBin := "def bar(a, b):\n return a + b\n" + err = c.ScriptSet(keyScript1, DeviceCPU, scriptBin) + if err != nil { + t.Errorf("Error preparing for ScriptScan(), while issuing ScriptSet. error = %v", err) + return + } + scriptValues, err = c.ScriptScan() + assert.Nil(t, err) + assert.Equal(t, 1, len(scriptValues)) + assert.Equal(t, keyScript1, scriptValues[0][0]) + assert.Empty(t, scriptValues[0][1]) +}