Skip to content

Commit a2cd158

Browse files
committed
feat(crypt): Add fast encryption
1 parent 53f26d4 commit a2cd158

File tree

13 files changed

+1705
-493
lines changed

13 files changed

+1705
-493
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,8 @@ jobs:
2222
- name: Install Redis
2323
run: sudo apt-get install -y redis-server
2424

25+
- name: Install libsodium
26+
run: sudo apt-get install -y libsodium23
27+
2528
- name: Run tests
2629
run: python -m unittest discover tests/ --verbose

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [Unreleased]
9+
10+
### Added
11+
12+
- Tensor encryption
13+
- Encrypts all tensor weights in a file with minimal overhead
14+
- Doesn't encrypt tensor metadata, such as:
15+
- Tensor name
16+
- Tensor `dtype`
17+
- Tensor shape & size
18+
- Requires an up-to-date version of `libsodium`
19+
- Use `apt-get install libsodium23` on Ubuntu or Debian
20+
- On other platforms, follow the
21+
[installation instructions from the libsodium documentation](https://doc.libsodium.org/installation)
22+
- Takes up less than 500 KiB once installed
23+
- Uses a parallelized version of XSalsa20-Poly1305 as its encryption algorithm
24+
- Splits each tensor's weights into ≤ 2 MiB chunks, encrypted separately
25+
- Example usage: see [examples/encryption.py](examples/encryption.py)
26+
827
## [2.6.0] - 2023-10-30
928

1029
### Added
@@ -220,6 +239,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
220239
- `get_gpu_name`
221240
- `no_init_or_tensor`
222241

242+
[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.6.0...HEAD
223243
[2.6.0]: https://github.com/coreweave/tensorizer/compare/v2.5.1...v2.6.0
224244
[2.5.1]: https://github.com/coreweave/tensorizer/compare/v2.5.0...v2.5.1
225245
[2.5.0]: https://github.com/coreweave/tensorizer/compare/v2.4.0...v2.5.0

examples/encryption.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
"""Example: encrypting and decrypting tensorizer model weights.

Serializes a GPT-Neo model with fast passphrase-based encryption,
then deserializes it again, reporting the throughput of each step.
"""

import os
import tempfile
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from tensorizer import (
    DecryptionParams,
    EncryptionParams,
    TensorDeserializer,
    TensorSerializer,
)
from tensorizer.utils import no_init_or_tensor

model_ref = "EleutherAI/gpt-neo-2.7B"


def original_model(ref) -> torch.nn.Module:
    # Full pretrained weights, downloaded if not already cached.
    return AutoModelForCausalLM.from_pretrained(ref)


def empty_model(ref) -> torch.nn.Module:
    # Architecture only — weight tensors are left uninitialized so the
    # deserializer can fill them in without paying for random init.
    cfg = AutoConfig.from_pretrained(ref)
    with no_init_or_tensor():
        return AutoModelForCausalLM.from_config(cfg)


# Set a strong string or bytes passphrase here
passphrase: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD", "") or input(
    "Passphrase to use for encryption: "
)

fd, path = tempfile.mkstemp(prefix="encrypted-tensors")

try:
    # Encrypt a model during serialization
    enc = EncryptionParams.from_passphrase_fast(passphrase)

    model = original_model(model_ref)
    serialization_start = time.monotonic()

    serializer = TensorSerializer(path, encryption=enc)
    serializer.write_module(model)
    serializer.close()

    serialization_end = time.monotonic()
    del model

    # Then decrypt it again during deserialization
    dec = DecryptionParams.from_passphrase(passphrase)

    model = empty_model(model_ref)
    deserialization_start = time.monotonic()

    deserializer = TensorDeserializer(path, encryption=dec, plaid_mode=True)
    deserializer.load_into_module(model)
    deserializer.close()

    deserialization_end = time.monotonic()
    del model
finally:
    # The mkstemp descriptor was never handed to the (de)serializers,
    # which open the path themselves; close it and remove the file.
    os.close(fd)
    os.unlink(path)


def print_speed(prefix, start, end, size):
    # Report size in GiB and throughput in MiB/s.
    mebibyte = 1 << 20
    gibibyte = 1 << 30
    duration = end - start
    rate = size / duration
    gib = size / gibibyte
    mib_s = rate / mebibyte
    print(f"{prefix} {gib:.2f} GiB model in {duration:.2f} seconds, {mib_s:.2f} MiB/s")


print_speed(
    "Serialized and encrypted",
    serialization_start,
    serialization_end,
    serializer.total_tensor_bytes,
)

print_speed(
    "Deserialized encrypted",
    deserialization_start,
    deserialization_end,
    deserializer.total_tensor_bytes,
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies = [
1717
"boto3>=1.26.0",
1818
"redis>=5.0.0",
1919
"hiredis>=2.2.0",
20-
"pynacl>=1.5.0",
20+
"libnacl>=2.1.0"
2121
]
2222
classifiers = [
2323
"Programming Language :: Python :: 3",

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ numpy>=1.19.5
33
protobuf>=3.19.5
44
psutil>=5.9.4
55
boto3>=1.26.0
6+
redis==5.0.0
67
hiredis
7-
redis==5.0.0
8+
libnacl>=2.1.0

tensorizer/_NumpyTensor.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,74 @@
1313
8: torch.int64,
1414
}
1515

16-
# torch types with no numpy equivalents
17-
# i.e. the only ones that need to be opaque
18-
# Uses a comprehension to filter out any dtypes
19-
# that don't exist in older torch versions
20-
_ASYMMETRIC_TYPES = {
21-
getattr(torch, t)
16+
# Listing of types from a static copy of:
17+
# tuple(
18+
# dict.fromkeys(
19+
# str(t)
20+
# for t in vars(torch).values()
21+
# if isinstance(t, torch.dtype)
22+
# )
23+
# )
24+
_ALL_TYPES = {
25+
f"torch.{t}": v
2226
for t in (
23-
"bfloat16",
24-
"quint8",
27+
"uint8",
28+
"int8",
29+
"int16",
30+
"int32",
31+
"int64",
32+
"float16",
33+
"float32",
34+
"float64",
35+
"complex32",
36+
"complex64",
37+
"complex128",
38+
"bool",
2539
"qint8",
40+
"quint8",
2641
"qint32",
42+
"bfloat16",
2743
"quint4x2",
2844
"quint2x4",
29-
"complex32",
3045
)
31-
if hasattr(torch, t)
46+
if isinstance(v := getattr(torch, t, None), torch.dtype)
47+
}
48+
49+
# torch types with no numpy equivalents
50+
# i.e. the only ones that need to be opaque
51+
# Uses a comprehension to filter out any dtypes
52+
# that don't exist in older torch versions
53+
_ASYMMETRIC_TYPES = {
54+
_ALL_TYPES[t]
55+
for t in {
56+
"torch.bfloat16",
57+
"torch.quint8",
58+
"torch.qint8",
59+
"torch.qint32",
60+
"torch.quint4x2",
61+
"torch.quint2x4",
62+
"torch.complex32",
63+
}
64+
& _ALL_TYPES.keys()
3265
}
3366

3467
# These types aren't supported yet because they require supplemental
3568
# quantization parameters to deserialize correctly
3669
_UNSUPPORTED_TYPES = {
37-
getattr(torch, t)
38-
for t in (
39-
"quint8",
40-
"qint8",
41-
"qint32",
42-
"quint4x2",
43-
"quint2x4",
44-
)
45-
if hasattr(torch, t)
70+
_ALL_TYPES[t]
71+
for t in {
72+
"torch.quint8",
73+
"torch.qint8",
74+
"torch.qint32",
75+
"torch.quint4x2",
76+
"torch.quint2x4",
77+
}
78+
& _ALL_TYPES.keys()
4679
}
4780

48-
_DECODE_MAPPING = {str(t): t for t in _ASYMMETRIC_TYPES}
81+
_DECODE_MAPPING = {
82+
k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES
83+
}
4984

5085

5186
class _NumpyTensor(NamedTuple):
@@ -85,14 +120,12 @@ def from_buffer(
85120
buffer=buffer,
86121
offset=offset,
87122
)
88-
return cls(data=data,
89-
numpy_dtype=numpy_dtype,
90-
torch_dtype=torch_dtype)
123+
return cls(data=data, numpy_dtype=numpy_dtype, torch_dtype=torch_dtype)
91124

92125
@classmethod
93-
def from_tensor(cls,
94-
tensor: Union[torch.Tensor,
95-
torch.nn.Module]) -> "_NumpyTensor":
126+
def from_tensor(
127+
cls, tensor: Union[torch.Tensor, torch.nn.Module]
128+
) -> "_NumpyTensor":
96129
"""
97130
Converts a torch tensor into a `_NumpyTensor`.
98131
May use an opaque dtype for the numpy array stored in

0 commit comments

Comments
 (0)