@@ -1538,10 +1538,21 @@ def elastic_image_tensor(
1538
1538
1539
1539
device = image .device
1540
1540
dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
1541
+
1542
+ # Patch: elastic transform should support (cpu,f16) input
1543
+ is_cpu_half = device .type == "cpu" and dtype == torch .float16
1544
+ if is_cpu_half :
1545
+ image = image .to (torch .float32 )
1546
+ dtype = torch .float32
1547
+
1541
1548
# We are aware that if input image dtype is uint8 and displacement is float64 then
1542
1549
# displacement will be casted to float32 and all computations will be done with float32
1543
1550
# We can fix this later if needed
1544
1551
1552
+ expected_shape = (1 ,) + shape [- 2 :] + (2 ,)
1553
+ if expected_shape != displacement .shape :
1554
+ raise ValueError (f"Argument displacement shape should be { expected_shape } , but given { displacement .shape } " )
1555
+
1545
1556
if ndim > 4 :
1546
1557
image = image .reshape ((- 1 ,) + shape [- 3 :])
1547
1558
needs_unsquash = True
@@ -1561,6 +1572,9 @@ def elastic_image_tensor(
1561
1572
if needs_unsquash :
1562
1573
output = output .reshape (shape )
1563
1574
1575
+ if is_cpu_half :
1576
+ output = output .to (torch .float16 )
1577
+
1564
1578
return output
1565
1579
1566
1580
@@ -1676,6 +1690,9 @@ def elastic(
1676
1690
if not torch .jit .is_scripting ():
1677
1691
_log_api_usage_once (elastic )
1678
1692
1693
+ if not isinstance (displacement , torch .Tensor ):
1694
+ raise TypeError ("Argument displacement should be a Tensor" )
1695
+
1679
1696
if torch .jit .is_scripting () or is_simple_tensor (inpt ):
1680
1697
return elastic_image_tensor (inpt , displacement , interpolation = interpolation , fill = fill )
1681
1698
elif isinstance (inpt , datapoints ._datapoint .Datapoint ):
0 commit comments