Add vanilla DeepSpeech model #1399
Conversation
Thanks for the contribution! :) Before moving forward, we need to determine how to validate the model implementation, especially if the description implies that it should reproduce an existing paper. Here are strategies we have used in the past:
Do you have suggestions on how to do so? EDIT: cc #446 (comment)
Thanks for the response.
Could you please share examples of the pipelines as you did for wav2letter and wavernn? I could follow the same process for DeepSpeech. What do you think?
And thanks for contributing! :)
Ok, thanks for clarifying :) we'll just need to be careful about the phrasing in the description of the model to convey the idea that this might differ from the original implementation.
Sure, they are in the torchaudio repo: wav2letter, wavernn. (I've also updated the comment above with them.)
If the data used for training is not freely available to reproduce the pre-trained weights, then having a colab training pipeline using freely available data would be useful. The DeepSpeech2 paper quotes 8% WER on LibriSpeech for DS1 in Table 13, but that's using other data too for training. This discussion gets 21% with only LibriSpeech, which is of course far from the 8% mentioned (or SOTA). What numbers would you expect to get by doing the experiment you suggest in colab? Note that the wav2letter pipeline reports only the CER, gets 15% CER, and does not have a language model.
The model proposed in this PR matches the original DeepSpeech paper.
Is it ok if I add a training pipeline in this PR as you did for wav2letter and wavernn?
I found that TPU training is so far impractical for models with CTCLoss, so I can't train DeepSpeech that way. The mentioned DS1 implementation reports what it achieves. In summary, here is what I can do:
@vincentqb
This would be a good way of showing the convergence :) (By the way, once we are convinced of the convergence, I'm leaving the door open to merging only the model, without the pipeline, for now to keep the PR small.)
Having the pipeline here for train/validation also allows us to validate the pipeline, and potentially run it on our side with the other data you suggest.
Great! (I'm reading: …)
If you can share the training script at an early stage, feel free to do so; I can try it on our training cluster as well. Also, if you are writing a training script from scratch, please refer to the distributed training utility and entry point in the source separation script. That one achieves a much better code organization and decoupling of training logic and setup.
Just a quick update. I'll update the script for training DeepSpeech with the distributed training utility as you proposed. But I won't be able to check it manually because I have a single GPU. So the next steps are:
Does that sound good to you?
Update: in the implementation from Mozilla's DeepSpeech they used 75 epochs, 0.05 dropout, and a beam decoder. @vincentqb
Thanks for sending the script and logs :) Just to clarify: in your log, the training reaches 130.5% WER and validation 33.3% WER, right? I did a quick run of your script, and get a reduction from about 60 %WER to 33.6 %WER for validation, see log. The training batch goes from around 100 %WER to 21.9 %WER in 14 epochs. I'll let it run a bit to see if the validation continues to go down. Do you have a version of DeepSpeech with dropout as done by Mozilla? If you want to use a CTC beam search decoder, I'd suggest using awni's. Given that there is no language model here, this should be the same as doing argmax and removing duplicates and blanks. Does Mozilla do something different here?
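For reference, a minimal sketch of that argmax decode (the function name and the blank index of 0 are assumptions for illustration, not from the PR):

```python
import torch

def greedy_ctc_decode(emissions: torch.Tensor, blank: int = 0) -> list:
    # emissions: (batch_size, input_length, number_of_classes) model output
    indices = emissions.argmax(dim=-1)  # best class per frame
    decoded = []
    for seq in indices:
        tokens = torch.unique_consecutive(seq)  # collapse repeated frames
        tokens = tokens[tokens != blank]        # drop CTC blanks
        decoded.append(tokens.tolist())
    return decoded
```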
Actually, I calculated WER in training incorrectly, so I did not share it intentionally; then I fixed the WER calculation and measured WER for validation on dev-clean, and got 33% WER in 8 epochs.
Are you asking about probability coefficients for dropout?
Thanks for sharing Awni's gist. I'll definitely take a look. To be honest, I'm not into the details of how Mozilla's implementation works, but at first look it's based on https://github.com/parlance/ctcdecode, so the principle must be the same. How can I help you move it forward?
Regarding the dropout:
In the implementation from Mozilla they specify a default value of 0.05. I can tweak the model implementation to accept dropout coefficients in the constructor (if it's needed):
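A minimal sketch of what exposing the dropout coefficient in the constructor could look like (the layer layout is a simplified stand-in, not the exact PR code; only the first fully connected block is shown):

```python
import torch
import torch.nn as nn

class DeepSpeech(nn.Module):
    """Simplified sketch: dropout is a constructor argument instead of hardcoded."""

    def __init__(self, n_feature: int, n_hidden: int = 2048, dropout: float = 0.05) -> None:
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(n_feature, n_hidden),
            nn.Hardtanh(0, 20),   # clipped ReLU from the DeepSpeech paper
            nn.Dropout(dropout),  # Mozilla's default of 0.05
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(x)
```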
Anyway, I would like to hear your thoughts about that.
Actually, I calculated WER in training incorrectly, so I did not share it intentionally; then I fixed the WER calculation and measured WER for validation on dev-clean, and got 33% WER in 8 epochs.
Is the code and log in your previous comment corrected then?
Do you have a version of DeepSpeech with dropout as done by Mozilla?
Are you asking about probability coefficients for dropout?
I was asking for code to copy-paste and run with dropout so I can do a long run with it :) For now, I would not include it in the version for the PR.
When I look at Mozilla, I see layer 1 with 0.05 dropout as you mentioned, but layers 2, 3, and 6 also default to that value of 0.05.
Thanks for sharing Awni's gist. I'll definitely take a look. To be honest I'm not into details how Mozilla's implementation works but at the first look it's based on https://github.com/parlance/ctcdecode so the principle must be the same.
How can I help you to move it forward?
Thanks for pointing out the decoder that Mozilla uses. We won't include this as part of the PR.
I've had to restart the long run. I'll post the result here.
Here's a quick update for fun :) but I'll leave the code running longer. EDIT: see updated log, with a best 31.8 %WER for validation.
Thank you so much for sharing it. It's just a pleasure to see these graphs :) (especially when you can't reproduce them manually). The best validation WER so far (what I've found in the attached log) is
For lower validation WER:
I've run the model also including train-other-500, see log. Best validation: 28.3 %WER. We can see that the training continues to overfit, which is great in this context.
Thanks a lot for working on this :) I gave a few comments around the organization of the model.
torchaudio/models/deepspeech.py
Args:
    x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features).
Returns:
    Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
Thanks for updating the docstring :) as mentioned in the comment above, we want to keep batch first for the output too, e.g. wav2letter, conv_tasnet, wavernn.
Suggested change:
Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).
Making this change might not be ideal: (1) the pipeline needs another transpose to be compatible with the CTC loss, (2) the current output follows the RNN convention, though batch_first could be used, and (3) there might be a performance hit (though this should be measured). Thoughts?
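To illustrate point (1): with a batch-first output, the pipeline would need an extra transpose before torch.nn.CTCLoss, which expects (input_length, batch_size, number_of_classes). A minimal sketch; all shapes here are made up:

```python
import torch

ctc_loss = torch.nn.CTCLoss()  # blank index defaults to 0

output = torch.randn(8, 100, 29).log_softmax(-1)  # (batch, time, classes), batch first
log_probs = output.transpose(0, 1)                # (time, batch, classes) for CTCLoss

targets = torch.randint(1, 29, (8, 30), dtype=torch.long)  # non-blank labels
input_lengths = torch.full((8,), 100, dtype=torch.long)
target_lengths = torch.full((8,), 30, dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
```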
All suggestions look reasonable and I have already addressed your comments, thank you for sharing them with me. However, I still don't follow your 3rd point. Could you please elaborate a bit on why keeping batch_first can bring a performance hit?
In the mentioned example you are using simple DataParallel, and here is my understanding of how it works (on 2 devices for simplicity):
- the model is copied onto the two devices
- the dataloader splits a batch into 2 chunks
- each model copy is fed one chunk during the forward pass
- then, using the outputs of the models, each device calculates the CTC loss; the loss values are copied into a single tensor (in RAM) and the reduction is performed on the CPU
- on the backward pass, the gradients are computed on the 2 devices and then 'all-reduce'd on the CPU
If I understood the process correctly, what's the intuition for keeping batch_first in the model's output?
In summary: let's have batch first for the output too for readability and consistency with the other models :)
Could you please elaborate a bit on why keeping batch_first can bring a performance hit?
The question about performance that I was raising was whether the forward method in RNN is "faster" when batch is first or not. For instance, this small user test here does not seem to find a significant difference in speed. Based on this small test, let's not worry about this here then.
In the mentioned example you are using simple DataParallel, and here is my understanding of how it works (on 2 devices for simplicity):
DataParallel applies only to the computation of the model (a minimal sketch follows the list):
- the model is copied to the two devices
- the batch is split along the first dimension and sent to each device
- the forward is computed on each device
- the output tensors are concatenated along the first dimension
- the loss function is then computed on that new tensor
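A minimal sketch of that flow, using a stand-in module rather than the actual DeepSpeech model:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(161, 29).to(device)  # stand-in for the DeepSpeech model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)     # replicates the module onto each device

batch = torch.randn(32, 161, device=device)  # split along dim 0 across replicas
output = model(batch)                        # gathered back to (32, 29) on one device
loss = output.pow(2).mean()                  # loss computed on the gathered tensor
loss.backward()
```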
Thank you for the detailed explanation. Let me know if anything is wrong.
LGTM! I just did a re-run with the latest version, and got a slightly better best validation of 0.259 at epoch 22.
I'll push a commit to change the names of the parameters to be more consistent with WaveRNN and the convention described in the README, wait for the tests to run, and merge this PR. Thanks for working on this! :)
Updated the naming convention, waiting for the tests to run :)
Force-pushed from 5f8a78d to 7b59009 (rebased)
(I confirm the failures are unrelated to this PR. In particular, the CodeQL failure is also occurring on this one.)
Merged, thanks again for working on this!
Co-authored-by: Vincent Quenneville-Belair <[email protected]>
Add vanilla DeepSpeech model according to its paper, Deep Speech: Scaling up end-to-end speech recognition. Relates to #446