@@ -317,5 +317,30 @@ def test_draw_keypoints_errors():
317
317
utils .draw_keypoints (image = img , keypoints = invalid_keypoints )
318
318
319
319
320
+ def test_flow_to_image ():
321
+ h , w = 100 , 100
322
+ flow = torch .meshgrid (torch .arange (h ), torch .arange (w ), indexing = "ij" )
323
+ flow = torch .stack (flow [::- 1 ], dim = 0 ).float ()
324
+ flow [0 ] -= h / 2
325
+ flow [1 ] -= w / 2
326
+ img = utils .flow_to_image (flow )
327
+ path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" , "expected_flow.pt" )
328
+ expected_img = torch .load (path , map_location = "cpu" )
329
+ assert_equal (expected_img , img )
330
+
331
+
332
+ def test_flow_to_image_errors ():
333
+ wrong_flow1 = torch .full ((3 , 10 , 10 ), 0 , dtype = torch .float )
334
+ wrong_flow2 = torch .full ((2 , 10 ), 0 , dtype = torch .float )
335
+ wrong_flow3 = torch .full ((2 , 10 , 30 ), 0 , dtype = torch .int )
336
+
337
+ with pytest .raises (ValueError , match = "Input flow should have shape" ):
338
+ utils .flow_to_image (flow = wrong_flow1 )
339
+ with pytest .raises (ValueError , match = "Input flow should have shape" ):
340
+ utils .flow_to_image (flow = wrong_flow2 )
341
+ with pytest .raises (ValueError , match = "Flow should be of dtype torch.float" ):
342
+ utils .flow_to_image (flow = wrong_flow3 )
343
+
344
+
320
345
if __name__ == "__main__" :
321
346
pytest .main ([__file__ ])
0 commit comments