Skip to content

Commit ca99c1c

Browse files
authored
fix nf4 test that is failing in CI (#3216)
1 parent 158e72c commit ca99c1c

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

torchao/dtypes/nf4tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,26 @@ def _(*args, **kwargs):
11291129
return out
11301130

11311131

1132+
@implements_torch_function(torch.Tensor.view_as)
1133+
def function_view_as(*args, **kwargs):
1134+
"""Handle view_as for NF4Tensor.
1135+
1136+
When view_as is called (typically by autograd internals), we need to return
1137+
a fresh NF4Tensor without autograd metadata to avoid conflicts.
1138+
"""
1139+
tensor = args[0]
1140+
1141+
# Create a new NF4Tensor with detached inner tensors to avoid autograd conflicts
1142+
updated_attrs = {}
1143+
tensor_attrs, _ = tensor.__tensor_flatten__()
1144+
for attr in tensor_attrs:
1145+
inner_tensor = getattr(tensor, attr)
1146+
# Detach to create a fresh tensor without autograd metadata
1147+
updated_attrs[attr] = inner_tensor.detach()
1148+
1149+
return NF4Tensor(*construct_nf4_args(tensor, updated_attrs))
1150+
1151+
11321152
@torch._dynamo.allow_in_graph
11331153
def nf4_constructor(
11341154
tensor_meta: SubclassTensorArgs,

0 commit comments

Comments
 (0)