import os
import tempfile
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from tensorizer import (
    DecryptionParams,
    EncryptionParams,
    TensorDeserializer,
    TensorSerializer,
)
from tensorizer.utils import no_init_or_tensor

model_ref = "EleutherAI/gpt-neo-2.7B"


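# Load the full pretrained model; this is what gets serialized and encrypted.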
def original_model(ref) -> torch.nn.Module:
    return AutoModelForCausalLM.from_pretrained(ref)


def empty_model(ref) -> torch.nn.Module:
    config = AutoConfig.from_pretrained(ref)
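    # no_init_or_tensor() skips allocating and initializing weights, so the
    # "empty" model is cheap to construct; the deserializer fills in the
    # real tensors afterwards.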
    with no_init_or_tensor():
        return AutoModelForCausalLM.from_config(config)


# Set a strong string or bytes passphrase here
passphrase: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD", "") or input(
    "Passphrase to use for encryption: "
)

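# Serialize into a temporary file for this demo; it is removed again in the
# finally block below.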
fd, path = tempfile.mkstemp(prefix="encrypted-tensors")

try:
    # Encrypt a model during serialization
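    # EncryptionParams derives a symmetric encryption key from the passphrase;
    # the serializer uses it to encrypt the tensor data as it is written.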
    encryption_params = EncryptionParams.from_passphrase_fast(passphrase)

    model = original_model(model_ref)
    serialization_start = time.monotonic()

    serializer = TensorSerializer(path, encryption=encryption_params)
    serializer.write_module(model)
    serializer.close()

    serialization_end = time.monotonic()
    del model

    # Then decrypt it again during deserialization
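    # Only the original passphrase is needed here; the key derivation
    # parameters are recovered from the serialized file.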
    decryption_params = DecryptionParams.from_passphrase(passphrase)

    model = empty_model(model_ref)
    deserialization_start = time.monotonic()

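    # plaid_mode is a faster loading mode that reuses buffers while
    # streaming tensors in.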
    deserializer = TensorDeserializer(
        path, encryption=decryption_params, plaid_mode=True
    )
    deserializer.load_into_module(model)
    deserializer.close()

    deserialization_end = time.monotonic()
    del model
finally:
    os.close(fd)
    os.unlink(path)


def print_speed(prefix, start, end, size):
    mebibyte = 1 << 20
    gibibyte = 1 << 30
    duration = end - start
    rate = size / duration
    print(
        f"{prefix} {size / gibibyte:.2f} GiB model in {duration:.2f} seconds,"
        f" {rate / mebibyte:.2f} MiB/s"
    )


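# total_tensor_bytes is the amount of tensor data handled in each phase,
# which gives the effective throughput reported below.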
print_speed(
    "Serialized and encrypted",
    serialization_start,
    serialization_end,
    serializer.total_tensor_bytes,
)

print_speed(
    "Deserialized encrypted",
    deserialization_start,
    deserialization_end,
    deserializer.total_tensor_bytes,
)