diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 4f81cc7907..6cce402ddf 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -2078,7 +2078,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
     # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
     zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
-    neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+    neg_inf = op.Constant(value=ir.tensor(query.dtype.min, dtype=query.dtype))
     attn_mask = op.Where(attn_mask, zero, neg_inf)
     attn_weight = op.Softmax(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
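
Not part of the patch: a minimal NumPy sketch of the usual motivation for filling masked positions with the dtype's lowest finite value instead of -inf in the additive softmax mask. When every position in a row is masked out, an all--inf row makes Softmax produce NaN, while a finite minimum keeps the output well defined. The softmax helper, the toy shapes, and the fill values below are illustrative only; query.dtype.min is taken here to mean the lowest finite value of the query's dtype.

import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax for the illustration.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

scores = np.zeros((1, 4), dtype=np.float32)            # toy attention logits
attn_mask = np.array([[False, False, False, False]])   # a fully masked row

# Old fill value: -inf. The fully masked row turns into NaN after Softmax.
neg_inf_fill = np.where(attn_mask, 0.0, -np.inf).astype(np.float32)
print(softmax(scores + neg_inf_fill))  # [[nan nan nan nan]]

# New fill value: the dtype's lowest finite value. The row stays finite.
min_fill = np.where(attn_mask, 0.0, np.finfo(np.float32).min).astype(np.float32)
print(softmax(scores + min_fill))      # [[0.25 0.25 0.25 0.25]]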