diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 05bac181ca..c9d870bd86 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -330,8 +330,9 @@ def aten_linalg_vector_norm( keepdim = False else: dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - self = op.Abs(self) + if math.isinf(ord): + self = op.Abs(self) if ord > 0: return op.ReduceMax(self, dim, keepdims=keepdim) else: @@ -345,6 +346,9 @@ def aten_linalg_vector_norm( elif ord == 2.0: return op.ReduceL2(self, dim, keepdims=keepdim) else: + if ord < 0 or ord % 2 != 0: + # Not an even integer (could be odd, fractional or negative), use Abs + self = op.Abs(self) self_pow = op.Pow(self, ord) exp = op.CastLike(1 / ord, self) return op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), exp)