diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts
index e310076d48..22565e4567 100644
--- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts
+++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts
@@ -1,5 +1,5 @@
 import { assert, it, describe } from "vitest";
-import { parseSafetensorsMetadata } from "./parse-safetensors-metadata";
+import { RE_SAFETENSORS_SHARD_FILE, parseSafetensorsMetadata } from "./parse-safetensors-metadata";
 import { sum } from "../utils/sum";
 
 describe("parseSafetensorsMetadata", () => {
@@ -109,4 +109,19 @@ describe("parseSafetensorsMetadata", () => {
 		assert.deepStrictEqual(parse.parameterCount, { BF16: 8_537_680_896 });
 		assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
 	});
+
+	it("should detect sharded safetensors filename", async () => {
+		const safetensorsPath = "model00002-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00002-of-00072.safetensors
+		const match = safetensorsPath.match(RE_SAFETENSORS_SHARD_FILE);
+
+		assert.strictEqual(RE_SAFETENSORS_SHARD_FILE.test(safetensorsPath), true);
+		assert.strictEqual(match?.[1], "00002");
+		assert.strictEqual(match?.[2], "00072");
+
+		const safetensorsPathWithDash = "model-00002-of-00072.safetensors"; // https://huggingface.co/google/gemma-7b/blob/7aeedade2bfdf69adddb754cff0461e74541e436/model-00001-of-00004.safetensors
+		assert.strictEqual(RE_SAFETENSORS_SHARD_FILE.test(safetensorsPathWithDash), true);
+
+		const safetensorsPathWithUnderscore = "model_00002-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00002-of-00072.safetensors
+		assert.strictEqual(RE_SAFETENSORS_SHARD_FILE.test(safetensorsPathWithUnderscore), true);
+	});
 });
diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts
index 901f354072..0bcb6f7bd7 100644
--- a/packages/hub/src/lib/parse-safetensors-metadata.ts
+++ b/packages/hub/src/lib/parse-safetensors-metadata.ts
@@ -14,7 +14,7 @@ export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
 /// but in some situations safetensors weights have different filenames.
 export const RE_SAFETENSORS_FILE = /\.safetensors$/;
 export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
-export const RE_SAFETENSORS_SHARD_FILE = /\d{5}-of-\d{5}\.safetensors$/;
+export const RE_SAFETENSORS_SHARD_FILE = /[-_]?(\d{5})-of-(\d{5})\.safetensors$/;
 
 const PARALLEL_DOWNLOADS = 5;
 const MAX_HEADER_LENGTH = 25_000_000;