Conversation

@lessw2020 (Contributor) commented Jul 30, 2025

When trying to run DS v3, 16B, I encounter the following cudaErrorAssert:

[rank0]:[titan] 2025-07-30 09:13:06,065 - root - INFO - Model deepseek_v3 16B size: 15,706,484,224 total parameters
[rank0]:[titan] 2025-07-30 09:13:06,079 - root - INFO - Applied full activation checkpointing to the model
[rank0]:[titan] 2025-07-30 09:13:06,176 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-30 09:13:06,520 - root - INFO - Peak FLOPS used for computing MFU: 2.250e+15
[rank0]:[titan] 2025-07-30 09:13:06,521 - root - INFO - CUDA memory usage for model: 8.80GiB(4.93%)
[rank0]:[titan] 2025-07-30 09:13:06,522 - root - WARNING - Warmup steps (200) exceed total training steps (20). Adjusting warmup steps to 20.
[rank0]:[titan] 2025-07-30 09:13:06,522 - root - WARNING - Warmup (20) + decay (16) steps exceed total training steps (20). Adjusting decay steps to 0.
[rank0]:[titan] 2025-07-30 09:13:06,522 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-30 09:13:06,522 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 20 (warmup 200)
[rank0]:[titan] 2025-07-30 09:13:06,522 - root - INFO - Training starts at step 1
[rank0]:/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [8027,0,0], thread: [64,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
[rank0]:/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [8027,0,0], thread: [65,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
[rank0]:/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [8027,0,0], thread: [66,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
[rank0]:/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [8027,0,0], thread: [67,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.

The root cause is that the configured model vocab size of 102,400 is smaller than the tokenizer's vocab of 128,000 plus 818 added special tokens, for a total vocab of 128,815 (max token id + 1).
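The failure mode is easy to check for up front: any token id at or above the embedding table size will trip the CUDA gather assert (or an `IndexError` on CPU). A minimal preflight check might look like this (`check_vocab_fits` is a hypothetical helper, not torchtitan code):

```python
def check_vocab_fits(model_vocab_size: int, tokenizer_total_vocab: int) -> bool:
    """True iff every token id the tokenizer can emit indexes the embedding safely.

    tokenizer_total_vocab is the max token id + 1.
    """
    return tokenizer_total_vocab <= model_vocab_size

# The 16B config before this PR: 102,400 embeddings vs. a 128,815-entry tokenizer.
print(check_vocab_fits(102_400, 128_815))  # False -> out-of-bounds gather on GPU
# After syncing to the 671B vocab size:
print(check_vocab_fits(129_280, 128_815))  # True
```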

This PR thus:
- updates the model config to use a correct vocab_size of ≥ 128,815 in both the args and init.py.
- update: let's move to match the 671B vocab size of 129,280 for consistency; this is also divisible by 8, which addresses my perf concern below at the same time.

*Note: there's a separate discussion to be had about rounding this size up to the next power of 2 or similar (divisible by 8) for better perf.
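The rounding mentioned in the note can be sketched as follows (`round_up` is a hypothetical helper for illustration):

```python
def round_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple, e.g. to pad a vocab for aligned GPU kernels."""
    return ((n + multiple - 1) // multiple) * multiple

print(round_up(128_815, 8))  # 128816 -> smallest div-by-8 size covering the tokenizer
print(129_280 % 8)           # 0 -> the 671B vocab size is already aligned
```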

This lets users run successfully and avoids the crash from the cudaErrorAssert.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 30, 2025
```diff
 ),
 "16B": DeepSeekV3ModelArgs(
-    vocab_size=102400,
+    vocab_size=128815,
```
Contributor:

Is there a target checkpoint we could load, e.g. DeepSeek MoE 16B?
If so we probably want to align with that, otherwise we should align with 671B.
cc: @wwwjn

@lessw2020 (Contributor, Author) commented Jul 30, 2025

Not sure if there is a checkpoint for 16B, but for consistency we can move to 129,280 (the 671B size), which is just ~465 tokens larger, and then we have consistent sizing for all three variations.
Updated the PR to match 671B.

@lessw2020 (Contributor, Author) commented:

This also addresses my point above about wanting a size that is divisible by 8 for better perf, so that's an extra bonus from syncing to 129,280.

@wwwjn (Contributor) commented Jul 30, 2025

I agree 129,280 is better. I was thinking of removing the 16B config completely, cc @tianyu-l. But I saw one user trying to load the 16B Moonlight model weights into our 16B model; the Moonlight-16B model should be very close to DeepSeek-16B. The vocab size for this Moonlight model is `"vocab_size": 163840`.

@lessw2020 lessw2020 changed the title [deepseek] fix mismatch of model embeddings vs tokenizer size (128815), avoids cudaErrorAssert and run termination. [deepseek] fix mismatch of model embeddings vs tokenizer size (128815), sync to 671B vocab size, avoids cudaErrorAssert and run termination. Jul 30, 2025
@wwwjn (Contributor) commented Jul 30, 2025

To summarize, I think the ideal way to solve this issue is: we use a correct tokenizer for the 16B model and keep the 16B model's parameters as they are (users might not realize the vocab size was changed by us and no longer matches the DeepSeek-AI official params).

Because we don't have an official 16B model tokenizer, we could use our test tokenizer (vocab size 2000) in train_configs, or tokenizer from here https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/config.json

@ebsmothers (Contributor) commented:

I agree with @wwwjn. Assuming the 16B model stays around, it would be better not to increase the model's vocab size. Otherwise loading checkpoints like https://huggingface.co/deepseek-ai/deepseek-moe-16b-chat will fail without some kind of custom handling of embedding weights.

@lessw2020 (Contributor, Author) commented:

Ah I see - I didn't realize there was an 'actual' 16B model with checkpoints available (I thought this was just a scaled-down dev model meant to run within a single node for development/testing).
Given that, I agree with @wwwjn: let's leave the vocab size alone and just update to their official tokenizer, and thus have a fully matching 16B.
I would also vote to keep 16B around, again for dev velocity; it's useful to have something meaningful that still fits within a single node. debug_model is often too small to see the impact of changes.

@lessw2020 (Contributor, Author) commented:

Opened a new PR which uses the 16B tokenizer here:
#1497

lessw2020 added a commit that referenced this pull request Jul 31, 2025
Based on the discussion in this PR
(#1495), the conclusion was to
ensure that 16B uses the proper tokenizer, avoiding the cudaErrorAssert
in the current config that comes from a mismatch between the embedding
and tokenizer vocab sizes.

Thus, this PR:
1 - adds an additional line to the README enabling users to pull the
16B-chat tokenizer,
2 - updates the 16_toml config to point to the 16B tokenizer under
/assets/tokenizer/deepseek-moe-16b-chat

With that, the vocab size of 102400 already in the toml now works
flawlessly.


**Testing:**
- ran the tokenizer download
- ran 20 iters with 16B without issue.

<img width="1255" height="201" alt="Screenshot 2025-07-30 at 12 46
38 PM"
src="https://github.com/user-attachments/assets/e33556bf-51c6-4fa0-ab71-d1b02ef74d99"
/>
@lessw2020 (Contributor, Author) commented:

Closing, as we resolved this in the linked PR in a better way.

@lessw2020 lessw2020 closed this Jul 31, 2025
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025