
Conversation

@kunigori

Adds SafeTensors-based serialization for PyTorch models (addresses #2532) and
implements metadata-driven loading to integrate cleanly with the materializer
workflow (per @bcdurak's feedback).

Changes

  • ✅ Add safetensors optional extra in pyproject.toml
  • ✅ Save state_dict to .safetensors when available; fall back to .pt with a warning
  • ✅ Write metadata.json (class_path, serialization_format; init_args / init_kwargs / factory_path reserved for Phase 2)
  • ✅ Use TemporaryDirectory + copy_dir() for remote stores
  • ✅ load() always returns nn.Module
  • ✅ Backward compat: supports weights.pt, checkpoint.pt, and legacy entire_model.pt

New artifact layout

artifact_uri/
├─ weights.safetensors   # or weights.pt on fallback
└─ metadata.json         # class_path + format
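
For reference, here is a minimal sketch of the save path under the Phase 1 assumptions (save_model, its directory argument, and the logger wiring are illustrative, not the PR's actual code):

import json
import logging
import os

import torch

logger = logging.getLogger(__name__)

def save_model(model: torch.nn.Module, directory: str) -> None:
    # Record enough metadata to reconstruct the model at load time.
    metadata = {
        "class_path": f"{type(model).__module__}.{type(model).__qualname__}",
        "serialization_format": "safetensors",
    }
    try:
        from safetensors.torch import save_file
        save_file(model.state_dict(), os.path.join(directory, "weights.safetensors"))
    except ImportError:
        # safetensors extra not installed: fall back to pickle-based .pt.
        logger.warning("safetensors not installed; falling back to torch.save")
        metadata["serialization_format"] = "pickle"
        torch.save(model.state_dict(), os.path.join(directory, "weights.pt"))
    with open(os.path.join(directory, "metadata.json"), "w") as f:
        json.dump(metadata, f)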

Metadata

{
  "class_path": "my_package.models.MyModel",
  "serialization_format": "safetensors",
  "init_args": [],
  "init_kwargs": {},
  "factory_path": null
}
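
And a hedged sketch of the metadata-driven load path (load_model is an illustrative name; Phase 1 only consumes class_path and serialization_format, and assumes a zero-argument __init__):

import importlib
import json
import os

import torch

def load_model(directory: str) -> torch.nn.Module:
    with open(os.path.join(directory, "metadata.json")) as f:
        metadata = json.load(f)
    # Import the model class from its dotted path and instantiate it
    # (Phase 1 assumes a zero-argument __init__).
    module_name, _, class_name = metadata["class_path"].rpartition(".")
    model_cls = getattr(importlib.import_module(module_name), class_name)
    model = model_cls()
    if metadata["serialization_format"] == "safetensors":
        from safetensors.torch import load_file
        state_dict = load_file(os.path.join(directory, "weights.safetensors"))
    else:
        state_dict = torch.load(os.path.join(directory, "weights.pt"))
    model.load_state_dict(state_dict)
    return model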

Why SafeTensors?

  • Security: Avoids pickle-based code execution risks
  • Performance: Faster, memory-mapped weight loads
  • Compatibility: Works with S3/GCS/Azure via artifact stores

Tests

Local run:

pytest tests/unit/integrations/pytorch/materializers/test_pytorch_module_materializer.py -v
# 4 passed in 1.88s

Coverage:

  • Round-trip with safetensors (see the test sketch after this list)
  • Pickle fallback path
  • Metadata-driven load
  • Legacy formats (weights.pt, checkpoint.pt, entire_model.pt)
  • Clear error when safetensors extra is missing at load
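
For flavor, a self-contained round-trip sketch (TinyModel and the direct safetensors calls are illustrative, not lifted from the PR's test file):

import torch
from safetensors.torch import load_file, save_file

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

def test_state_dict_round_trip(tmp_path):
    model = TinyModel()
    path = str(tmp_path / "weights.safetensors")
    save_file(model.state_dict(), path)
    # Reload into a fresh instance and compare tensors exactly.
    fresh = TinyModel()
    fresh.load_state_dict(load_file(path))
    for key, tensor in model.state_dict().items():
        assert torch.equal(tensor, fresh.state_dict()[key])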

Known limitations (Phase 1)

  • Zero-argument __init__() requirement: models needing config should use
    a factory method (planned for Phase 2); a workaround sketch follows this list

  • Legacy artifacts without metadata (weights.pt / checkpoint.pt) require an
    explicit data_type:

    model = materializer.load(data_type=MyModel)

  • Legacy entire_model.pt is loaded and returned as a Module directly
    (no data_type needed)
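
Until Phase 2 lands, one hedged workaround for the zero-argument constraint is to pin configuration in a zero-argument subclass (ConfiguredNet / DefaultConfiguredNet are made-up names):

import torch

class ConfiguredNet(torch.nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.layer = torch.nn.Linear(hidden, 1)

class DefaultConfiguredNet(ConfiguredNet):
    # Zero-argument wrapper: the Phase 1 loader can call this without args.
    def __init__(self):
        super().__init__(hidden=128)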

Documentation

Happy to add a short guide covering why/how/limits/troubleshooting.
Which file should I update?

  • docs/book/component-guide/materializers/pytorch.md (materializer behavior)?
  • docs/book/integration-guide/pytorch.md (integration landing)?

Or would you prefer a new section?

Future work (separate PRs)

  • Phase 2: Support init_args / init_kwargs / factory functions
  • Phase 3: PyTorch Lightning materializer
  • Phase 4: HuggingFace Transformers support

Checklist

  • Tests pass locally
  • Code formatted (ruff check --fix + ruff format)
  • Also ran project scripts: bash scripts/format.sh and bash scripts/lint.sh
  • Type hints added (mypy clean)
  • Backward compatibility maintained
  • Rebased on develop
  • Documentation updated (pending guidance on location)
  • CLA signed

@review-notebook-app

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@CLAassistant

CLAassistant commented Nov 11, 2025

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


yusuke kunimitsu seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@kunigori kunigori changed the base branch from main to develop November 11, 2025 04:18
@schustmi
Contributor

Hey @kunigori, thanks for the PR! Can you please base your changes on the develop branch and then also change the target of this PR?
