@@ -426,6 +426,61 @@ def test_pytorch_scriptrun(env):
426426 values2 = con2 .execute_command ('AI.TENSORGET' , 'c' , 'VALUES' )
427427 env .assertEqual (values2 , values )
428428
429+
430+ def test_pytorch_scriptrun_variadic (env ):
431+ if not TEST_PT :
432+ env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
433+ return
434+
435+ con = env .getConnection ()
436+
437+ test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
438+ script_filename = os .path .join (test_data_path , 'script.txt' )
439+
440+ with open (script_filename , 'rb' ) as f :
441+ script = f .read ()
442+
443+ ret = con .execute_command ('AI.SCRIPTSET' , 'myscript' , DEVICE , 'TAG' , 'version1' , 'SOURCE' , script )
444+ env .assertEqual (ret , b'OK' )
445+
446+ ret = con .execute_command ('AI.TENSORSET' , 'a' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
447+ env .assertEqual (ret , b'OK' )
448+ ret = con .execute_command ('AI.TENSORSET' , 'b1' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
449+ env .assertEqual (ret , b'OK' )
450+ ret = con .execute_command ('AI.TENSORSET' , 'b2' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
451+ env .assertEqual (ret , b'OK' )
452+
453+ ensureSlaveSynced (con , env )
454+
455+ for _ in range ( 0 ,100 ):
456+ ret = con .execute_command ('AI.SCRIPTRUN' , 'myscript' , 'bar_variadic' , 'INPUTS' , 'a' , '$' , 'b1' , 'b2' , 'OUTPUTS' , 'c' )
457+ env .assertEqual (ret , b'OK' )
458+
459+ ensureSlaveSynced (con , env )
460+
461+ info = con .execute_command ('AI.INFO' , 'myscript' )
462+ info_dict_0 = info_to_dict (info )
463+
464+ env .assertEqual (info_dict_0 ['key' ], 'myscript' )
465+ env .assertEqual (info_dict_0 ['type' ], 'SCRIPT' )
466+ env .assertEqual (info_dict_0 ['backend' ], 'TORCH' )
467+ env .assertEqual (info_dict_0 ['tag' ], 'version1' )
468+ env .assertTrue (info_dict_0 ['duration' ] > 0 )
469+ env .assertEqual (info_dict_0 ['samples' ], - 1 )
470+ env .assertEqual (info_dict_0 ['calls' ], 100 )
471+ env .assertEqual (info_dict_0 ['errors' ], 0 )
472+
473+ values = con .execute_command ('AI.TENSORGET' , 'c' , 'VALUES' )
474+ env .assertEqual (values , [b'4' , b'6' , b'4' , b'6' ])
475+
476+ ensureSlaveSynced (con , env )
477+
478+ if env .useSlaves :
479+ con2 = env .getSlaveConnection ()
480+ values2 = con2 .execute_command ('AI.TENSORGET' , 'c' , 'VALUES' )
481+ env .assertEqual (values2 , values )
482+
483+
429484def test_pytorch_scriptrun_errors (env ):
430485 if not TEST_PT :
431486 env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
@@ -528,6 +583,66 @@ def test_pytorch_scriptrun_errors(env):
528583 env .assertEqual (type (exception ), redis .exceptions .ResponseError )
529584
530585
586+ def test_pytorch_scriptrun_errors (env ):
587+ if not TEST_PT :
588+ env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
589+ return
590+
591+ con = env .getConnection ()
592+
593+ test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
594+ script_filename = os .path .join (test_data_path , 'script.txt' )
595+
596+ with open (script_filename , 'rb' ) as f :
597+ script = f .read ()
598+
599+ ret = con .execute_command ('AI.SCRIPTSET' , 'ket' , DEVICE , 'TAG' , 'asdf' , 'SOURCE' , script )
600+ env .assertEqual (ret , b'OK' )
601+
602+ ret = con .execute_command ('AI.TENSORSET' , 'a' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
603+ env .assertEqual (ret , b'OK' )
604+ ret = con .execute_command ('AI.TENSORSET' , 'b' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
605+ env .assertEqual (ret , b'OK' )
606+
607+ ensureSlaveSynced (con , env )
608+
609+ # ERR Variadic input key is empty
610+ try :
611+ con .execute_command ('DEL' , 'EMPTY' )
612+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'a' , '$' , 'EMPTY' , 'b' , 'OUTPUTS' , 'c' )
613+ except Exception as e :
614+ exception = e
615+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
616+ env .assertEqual ("tensor key is empty" , exception .__str__ ())
617+
618+ # ERR Variadic input key not tensor
619+ try :
620+ con .execute_command ('SET' , 'NOT_TENSOR' , 'BAR' )
621+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'a' , '$' , 'NOT_TENSOR' , 'b' , 'OUTPUTS' , 'c' )
622+ except Exception as e :
623+ exception = e
624+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
625+ env .assertEqual ("WRONGTYPE Operation against a key holding the wrong kind of value" , exception .__str__ ())
626+
627+ try :
628+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'b' , '$' , 'OUTPUTS' , 'c' )
629+ except Exception as e :
630+ exception = e
631+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
632+
633+ try :
634+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'b' , '$' , 'OUTPUTS' )
635+ except Exception as e :
636+ exception = e
637+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
638+
639+ try :
640+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , '$' , 'OUTPUTS' )
641+ except Exception as e :
642+ exception = e
643+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
644+
645+
531646def test_pytorch_scriptinfo (env ):
532647 if not TEST_PT :
533648 env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
0 commit comments