diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index f62a4f27a1..8184fd5eba 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2069,9 +2069,9 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( query_scaled = op.Mul(query, op.Sqrt(scale)) key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) - attn_mask = op.Where( - attn_mask, op.Constant(value_float=0.0), op.Constant(value_float=-float("inf")) - ) + zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype)) + neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype)) + attn_mask = op.Where(attn_mask, zero, neg_inf) attn_weight = op.Softmax( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), axis=-1,