Skip to content

Commit 77ca57d

Browse files
authored
Fix Safe Load for NF4 (#1241)
1 parent ccd883b commit 77ca57d

File tree

4 files changed

+111
-45
lines changed

4 files changed

+111
-45
lines changed

ruff.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ include = [
1212
"test/dtypes/test_affine_quantized_float.py",
1313
"test/dtypes/test_nf4.py",
1414
"test/prototype/low_bit_optim/**.py",
15+
"torchao/utils.py",
16+
1517
]
1618

1719
lint.ignore = ["E731"]

test/dtypes/test_nf4.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,10 @@ def test_load_from_state_dicts(self, dtype: torch.dtype):
170170
assert base_mod.param.block_size == 32
171171
assert base_mod.param.scaler_block_size == 2
172172

173-
@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
174173
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
175174
def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
176175
"""Tests loading to and from different module state dicts"""
177-
input_tensor = torch.rand(64, device="cuda", dtype=dtype)
176+
input_tensor = torch.rand(64, dtype=dtype)
178177
base_mod = self.TestMod(input_tensor, 32, 2)
179178
state_dict = base_mod.state_dict()
180179
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -184,11 +183,10 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
184183
assert other_mod.param.block_size == 32
185184
assert other_mod.param.scaler_block_size == 2
186185

187-
@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
188186
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
189187
def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
190188
"""Tests loading to and from different module state dicts"""
191-
input_tensor = torch.rand(128, device="cuda", dtype=dtype)
189+
input_tensor = torch.rand(128, dtype=dtype)
192190
base_mod = self.TestMod(input_tensor, 32, 2)
193191
state_dict = base_mod.state_dict()
194192
saved_state_dict = self.save_state_dict_to_buffer(state_dict)

torchao/dtypes/nf4tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from torch._prims_common import make_contiguous_strides_for
1111
from torch.distributed.device_mesh import DeviceMesh
1212

13+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
14+
1315
aten = torch.ops.aten
1416

1517
c10d_functional = torch.ops.c10d_functional
@@ -1043,3 +1045,7 @@ def nf4_constructor(
10431045
quantized_data,
10441046
nf4,
10451047
)
1048+
1049+
1050+
if TORCH_VERSION_AT_LEAST_2_5:
1051+
torch.serialization.add_safe_globals([NF4Tensor])

0 commit comments

Comments
 (0)