
Add vanilla DeepSpeech model #1399


Merged: 6 commits merged into pytorch:master on May 11, 2021

Conversation

@discort (Contributor) commented Mar 18, 2021

Relates to #446

@vincentqb (Contributor) commented Mar 18, 2021

Thanks for the contribution! :) Before moving forward, we need to determine how to validate the model implementation, especially if the description implies that it should reproduce an existing paper. Here are strategies we have used in the past: training pipelines such as those for wav2letter and wavernn in the torchaudio repo.

Do you have suggestions on how to do so?

EDIT: cc #446 (comment)

@discort (Contributor, Author) commented Mar 19, 2021

Thanks for the response, @vincentqb.

  • I used the wav2letter implementation as an example for this PR.
  • According to the paper, DeepSpeech was trained on 7,380 hours of audio, including private and paid datasets, so unfortunately I'm not able to reproduce exactly the same result and provide the weights for this model as you did for MobileNetV3.
  • I did a quick search and didn't find an existing exact implementation of DeepSpeech. I've seen only numerous implementations of DeepSpeech2 () or of DeepSpeech with custom modifications (not mentioned in the paper), for instance here.

Could you please share example pipelines as you did for wav2letter and wavernn? I could follow the same process for DeepSpeech. Or I could make a Colab notebook, train the model on a simple dataset (say, dev-clean from LibriSpeech) on a TPU, and share the result here.

What do you think?

@vincentqb (Contributor) commented:

Thanks for the response, @vincentqb.

And thanks for contributing! :)

  • I did a quick search and didn't find an existing exact implementation of DeepSpeech. I've seen only numerous implementations of DeepSpeech2 () or of DeepSpeech with custom modifications (not mentioned in the paper), for instance here.

OK, thanks for clarifying :) We'll just need to be careful about the phrasing in the model's description to convey that it might differ from the original implementation.

Could you please share example pipelines as you did for wav2letter and wavernn? I could follow the same process for DeepSpeech.

Sure, they are in the torchaudio repo: wav2letter, wavernn. (I've also updated the comment above with them.)

Or I could make a Colab notebook, train the model on a simple dataset (say, dev-clean from LibriSpeech) on a TPU, and share the result here.

If the data used to produce the pre-trained weights is not freely available, then having a Colab training pipeline using freely available data would be useful. The DeepSpeech2 paper quotes 8% WER on LibriSpeech for DS1 in Table 13, but that's with additional training data. This discussion gets 21% with only LibriSpeech, which is of course far from the 8% mentioned (or SOTA). What numbers would you expect to get from the experiment you suggest in Colab?

Note that the wav2letter pipeline reports only the CER, gets 15% CER, and does not use a language model.
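For reference, a minimal sketch of how WER is commonly computed from reference and hypothesis transcripts (plain Python, no torchaudio helpers; the function names here are only illustrative, not part of the pipeline discussed above):

def _edit_distance(ref, hyp):
    # Classic Levenshtein distance over token sequences, single-row DP.
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1,         # deletion
                                   d[j - 1] + 1,     # insertion
                                   prev + (r != h))  # substitution
    return d[-1]

def wer(reference: str, hypothesis: str) -> float:
    ref_words = reference.split()
    return _edit_distance(ref_words, hypothesis.split()) / len(ref_words)

print(wer("the cat sat on the mat", "the cat sat on mat"))  # 1/6 ≈ 0.167

CER is the same computation applied to character sequences instead of word sequences.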

@discort (Contributor, Author) commented Apr 1, 2021

OK, thanks for clarifying :) We'll just need to be careful about the phrasing in the model's description to convey that it might differ from the original implementation.

The model proposed in this PR matches the original DeepSpeech paper.

Sure, they are in the torchaudio repo: wav2letter, wavernn. (I've also updated the comment above with them.)

Is it OK if I add a training pipeline in this PR, as you did for wav2letter and wavernn?

If the data used to produce the pre-trained weights is not freely available, then having a Colab training pipeline using freely available data would be useful. The DeepSpeech2 paper quotes 8% WER on LibriSpeech for DS1 in Table 13, but that's with additional training data. This discussion gets 21% with only LibriSpeech, which is of course far from the 8% mentioned (or SOTA). What numbers would you expect to get from the experiment you suggest in Colab?

I found that TPUs are impractical so far for training models with CTCLoss, so I can't train DeepSpeech on train-clean-100, train-clean-360, and train-other-500 due to lack of resources.

The mentioned DS1 implementation claims to achieve 15.80 WER on dev-clean (see results). But it uses an LSTM instead of the RNN from the original paper, so MyrtleSoftware/deepspeech has more parameters.

In summary, here is what I can do:

  • add a train/val pipeline
  • train the model on train-clean-100 only and provide pre-trained weights

@vincentqb
What do you think about that?

@vincentqb (Contributor) commented:

Sure, they are in the torchaudio repo: wav2letter, wavernn. (I've also updated the comment above with them.)

Is it OK if I add a training pipeline in this PR, as you did for wav2letter and wavernn?

This would be a good way of showing convergence :) (By the way, once we are convinced of convergence, I'm leaving the door open to merging only the model without the pipeline for now, to keep the PR small.)

If the data used to produce the pre-trained weights is not freely available, then having a Colab training pipeline using freely available data would be useful. The DeepSpeech2 paper quotes 8% WER on LibriSpeech for DS1 in Table 13, but that's with additional training data. This discussion gets 21% with only LibriSpeech, which is of course far from the 8% mentioned (or SOTA). What numbers would you expect to get from the experiment you suggest in Colab?

I found that TPUs are impractical so far for training models with CTCLoss, so I can't train DeepSpeech on train-clean-100, train-clean-360, and train-other-500 due to lack of resources.

Having the pipeline here for train/validation also allows us to validate the pipeline, and potentially run it on our side with the other data you suggest.

The mentioned DS1 implementation claims to achieve 15.80 WER on dev-clean (see results). But it uses an LSTM instead of the RNN from the original paper, so MyrtleSoftware/deepspeech has more parameters.

In summary, here is what I can do:

  • add a train/val pipeline
  • train the model on train-clean-100 only and provide pre-trained weights

@vincentqb
What do you think about that?

Great!

(I'm reading: train-clean-100 for training and dev-clean for validation :) )

@mthrok (Collaborator) commented Apr 9, 2021

@discort

In summary, here is what I can do:

  • add a train/val pipeline
  • train the model on train-clean-100 only and provide pre-trained weights

If you can share the training script at an early stage, feel free to do so; I can try it on our training cluster as well.

Also, if you are writing a training script from scratch, please refer to the distributed training utility and entry point in the source separation script. That one achieves much better code organization and decoupling of training logic from setup; a rough sketch of the idea follows below.
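For illustration only, a minimal sketch of the kind of per-process entry point that distributed utility enables. This is not the actual torchaudio utility; the placeholder model, port, and omitted data loading are assumptions.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def _run_worker(rank: int, world_size: int) -> None:
    # One process per GPU; NCCL backend for multi-GPU training.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # assumed rendezvous address
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    model = nn.Linear(80, 29).to(rank)  # placeholder model; DeepSpeech would go here
    model = DDP(model, device_ids=[rank])
    # ... build a DataLoader with a DistributedSampler and run the training loop ...
    dist.destroy_process_group()

def main() -> None:
    world_size = torch.cuda.device_count()
    mp.spawn(_run_worker, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()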

@discort (Contributor, Author) commented Apr 9, 2021

Just a quick update.
I'm still training the model, and it seems to have converged. The problem is my slow GPU, but I can share the temporary log just to show you as 'proof'. You can find the training script here; it's based on examples from torchaudio.

train.log

@mthrok
cc @vincentqb

I'll update the script for training DeepSpeech with the distributed utility as you proposed, but I won't be able to check it manually because I have a single GPU.

So the next steps are:

  • update the DeepSpeech training script to use the distributed parallel loader, etc.
  • share the full train.log after 10 epochs (?)
  • share the weights (?)

Does that sound good to you?

@discort (Contributor, Author) commented Apr 12, 2021

UPD:
Trained on train-clean-100 and train-clean-360 and achieved 0.33 WER on dev-clean. I did not apply regularization such as weight decay or dropout.
Pre-trained weights are here, and some examples from dev-clean:
val.log

In Mozilla's DeepSpeech implementation they used 75 epochs, 0.05 dropout, and a beam decoder.

@vincentqb
Let me know if you need more information.

@vincentqb (Contributor) commented:

Thanks for sending the script and logs :) Just to clarify: in your log, the training reaches 130.5% WER and validation 33.3% WER, right?

I did a quick run of your script and got a reduction from about 60% WER to 33.6% WER for validation, see log. The training batch goes from around 100% WER to 21.9% WER in 14 epochs. I'll let it run a bit longer to see if the validation continues to go down.

Do you have a version of DeepSpeech with dropout as done by Mozilla?

If you want to use a CTC beam search decoder, I'd suggest using Awni's. Given that there is no language model here, this should be the same as taking the argmax and removing duplicates and blanks. Does Mozilla do something different here?
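For reference, a minimal sketch of that greedy decoding, assuming per-frame log-probabilities from the model and a hypothetical labels string whose index 0 is the blank token:

import torch

def greedy_ctc_decode(emissions: torch.Tensor, labels: str, blank: int = 0) -> str:
    # emissions: (input_length, num_classes) log-probabilities for one utterance.
    # With no language model, beam search reduces to: take the argmax per frame,
    # collapse repeated indices, then drop the blank token.
    indices = torch.argmax(emissions, dim=-1)
    indices = torch.unique_consecutive(indices).tolist()
    return "".join(labels[i] for i in indices if i != blank)

# Hypothetical usage: labels[0] is the blank symbol, the rest is the character set.
labels = "-" + " abcdefghijklmnopqrstuvwxyz'"
emissions = torch.randn(100, len(labels)).log_softmax(-1)
print(greedy_ctc_decode(emissions, labels))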

@discort (Contributor, Author) commented Apr 14, 2021

Thanks for sending the script and logs :) Just to clarify: in your log, the training reaches 130.5% WER and validation 33.3% WER, right?

Actually, I calculated WER incorrectly during training, so I intentionally did not share it. I then fixed the WER calculation, measured validation WER on dev-clean, and got 33% WER in 8 epochs.

Do you have a version of DeepSpeech with dropout as done by Mozilla?

Are you asking about probability coefficients for dropout?

If you want to use a CTC beam search decoder, I'd suggest using Awni's. Given that there is no language model here, this should be the same as taking the argmax and removing duplicates and blanks. Does Mozilla do something different here?

Thanks for sharing Awni's gist, I'll definitely take a look. To be honest, I don't know the details of how Mozilla's implementation works, but at first glance it's based on https://github.com/parlance/ctcdecode, so the principle must be the same.

How can I help move this forward?

@discort (Contributor, Author) commented Apr 17, 2021

Regarding dropout:
The original paper doesn't mention the exact dropout rate:

During training we apply a dropout [19] rate between 5% - 10%. We apply dropout in the feedforward layers but not to the recurrent hidden activations.

In Mozilla's implementation they specify a default dropout value of 0.05 just for the first FC layer. I'm not sure which coefficients they used to get the best WER.

I can tweak the model implementation to accept dropout coefficients in the constructor (if needed).
It would look something like this:

from typing import Optional, Tuple

import torch.nn as nn


class DeepSpeech(nn.Module):
    def __init__(self,
                 in_features: int,
                 hidden_size: int,
                 num_classes: int,
                 fc_dropouts: Optional[Tuple[float, ...]] = None) -> None:
        super(DeepSpeech, self).__init__()
        # One dropout probability per fully connected layer; default to no dropout.
        if fc_dropouts is None:
            fc_dropouts = (0.0, 0.0, 0.0, 0.0)

Anyway, I would like to hear your thoughts on that.

@vincentqb (Contributor) left a comment:

Actually, I calculated WER incorrectly during training, so I intentionally did not share it. I then fixed the WER calculation, measured validation WER on dev-clean, and got 33% WER in 8 epochs.

Is the code and log in your previous comment corrected then?

Do you have a version of DeepSpeech with dropout as done by Mozilla?

Are you asking about probability coefficients for dropout?

I was asking for code with dropout that I could copy-paste and run, so I can do a long run with it :) For now, I would not include it in the version for the PR.

When I look at Mozilla's code, I see layer 1 with 0.05 dropout as you mentioned, but layers 2, 3, and 6 also default to that value of 0.05.

Thanks for sharing Awni's gist, I'll definitely take a look. To be honest, I don't know the details of how Mozilla's implementation works, but at first glance it's based on https://github.com/parlance/ctcdecode, so the principle must be the same.

How can I help move this forward?

Thanks for pointing out the decoder that Mozilla uses. We won't include this as part of the PR.

I've had to restart the long run. I'll post the result here.

@vincentqb (Contributor) commented Apr 21, 2021

I've had to restart the long run. I'll post the result here.

Here's a quick update for fun :) but I'll leave the code running longer.

EDIT: see the updated log, with a best validation WER of 31.8%.

[plots: train_loss_log_460, train_wer_460, valid_loss_460, valid_wer_460]

@discort (Contributor, Author) commented Apr 21, 2021

Here's a quick update for fun :) but I'll leave the code running longer, see current log.

Thank you so much for sharing it. It's a pleasure to see these graphs :) (especially when you can't reproduce them yourself).


The best validation WER so far (that I've found in the attached log) is 0.318.
Just a couple of ideas on how to achieve a lower training loss:

For lower validation WER:
You ran training with train_data_urls=['train-clean-100', 'train-clean-360'].
For instance, Mozilla's implementation and this one also trained on train-other-500; more training data acts like a regularizer and should result in a lower validation WER.

@vincentqb (Contributor) commented:

I've run the model also including train-other-500, see log. Best validation: 28.3% WER.

[plots: train_loss_log_960, train_wer_960, valid_loss_960, valid_wer_960]

We can see that the training continues to overfit, which is great in this context.

@vincentqb (Contributor) left a comment:

Thanks a lot for working on this :) I gave a few comments around the organization of the model.

Args:
    x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features).
Returns:
    Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
@vincentqb (Contributor) commented:

Thanks for updating the docstring :) As mentioned in the comment above, we want to keep batch first for the output too, e.g. wav2letter, conv_tasnet, wavernn.

Suggested change:
- Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
+ Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).

Making this change might not be ideal: (1) the pipeline needs another transpose to be compatible with CTC, (2) the current output follows the RNN convention, though batch_first could be used, and (3) there might be a performance hit (though this should be measured). Thoughts?
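For context, a minimal sketch of the extra transpose mentioned in (1): torch.nn.CTCLoss expects log-probabilities of shape (input_length, batch_size, num_classes), so a batch-first model output needs one transpose in the pipeline. The shapes and random tensors below are illustrative only.

import torch
import torch.nn as nn

batch_size, input_length, num_classes, target_length = 8, 100, 29, 20

# Hypothetical batch-first model output, as in the suggested docstring.
logits = torch.randn(batch_size, input_length, num_classes)

# CTCLoss wants (input_length, batch_size, num_classes), hence the transpose.
log_probs = logits.log_softmax(dim=-1).transpose(0, 1)

targets = torch.randint(1, num_classes, (batch_size, target_length), dtype=torch.long)
input_lengths = torch.full((batch_size,), input_length, dtype=torch.long)
target_lengths = torch.full((batch_size,), target_length, dtype=torch.long)

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)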

@discort (Contributor, Author) commented:

All suggestions look reasonable, and I have already addressed your comments. Thank you for sharing them with me.

However, I still don't follow your 3rd point. Could you please elaborate a bit on why keeping batch_first can bring a performance hit?
In the mentioned example you are using simple DataParallel, and here is my understanding of how it works (on 2 devices for simplicity):

  • the model is copied to the two devices
  • the dataloader splits a batch into 2 chunks
  • each model copy is then fed one chunk during the forward pass
  • then, using the model outputs, each device calculates the CTC loss; the loss values are copied into a single tensor (in RAM) and the reduction is performed on the CPU
  • on the backward pass, the gradients are computed on the 2 devices and then 'all-reduce'd on the CPU

If I understood the process correctly, what's the intuition for keeping batch_first in the model's output?

@vincentqb (Contributor) commented:

In summary: let's have batch first for the output too for readability and consistency with the other models :)

Could you please elaborate a bit on why keeping batch_first can bring a performance hit?

The performance question I was raising was whether the RNN's forward method is "faster" when batch comes first or not. For instance, this small user test does not seem to find a significant difference in speed. Based on this small test, let's not worry about it here.

In the mentioned example you are using simple DataParallel, and here is my understanding of how it works (on 2 devices for simplicity):

DataParallel applies only to the computation of the model (see the sketch after this list):

  • the model is copied to the two devices
  • the batch is split along the first dimension and sent to each device
  • the forward pass is computed on each device
  • the output tensors are concatenated along the first dimension
  • the loss function is then computed on that new tensor
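A minimal sketch of that behaviour, assuming two GPUs and a stand-in model with batch-first input and output (the stand-in model and shapes are assumptions, not the DeepSpeech pipeline). Because DataParallel splits and re-gathers along dimension 0, a batch-first output composes cleanly with it:

import torch
import torch.nn as nn

# Stand-in model; applies to the last dimension, so it keeps batch first in and out.
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 29))
model = nn.DataParallel(model.cuda(), device_ids=[0, 1])

x = torch.randn(16, 100, 80).cuda()   # (batch_size, input_length, num_features)
out = model(x)                        # split along dim 0, forward per device, re-gathered
print(out.shape)                      # torch.Size([16, 100, 29]), still batch first

# The loss is then computed on the gathered batch-first tensor; for CTC this still
# needs the transpose to (input_length, batch_size, num_classes) shown earlier.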

@discort (Contributor, Author) commented:

Thank you for the detailed explanation. Let me know if anything is wrong.

@discort discort requested a review from vincentqb May 1, 2021 09:28
@vincentqb (Contributor) left a comment:

LGTM! I just did a re-run with the latest version, and got a slightly better best validation WER of 0.259 at epoch 22.

I'll push a commit to change the parameter names to be more consistent with WaveRNN and the convention described in the README, wait for the tests to run, and merge this PR. Thanks for working on this! :)

@vincentqb (Contributor) commented:

updated the naming convention, waiting for the tests to run :)

@vincentqb vincentqb force-pushed the add_deepspeech_model branch from 5f8a78d to 7b59009 Compare May 11, 2021 18:40
@vincentqb (Contributor) commented:

(rebased)

@vincentqb (Contributor) commented May 11, 2021

(I confirm the failures are unrelated to this PR. In particular, the CodeQL failure is also occurring on this one.)

@vincentqb vincentqb merged commit 1f13667 into pytorch:master May 11, 2021
@vincentqb (Contributor) commented May 11, 2021

Merged, thanks again for working on this!

carolineechen pushed a commit to carolineechen/audio that referenced this pull request May 12, 2021
Co-authored-by: Vincent Quenneville-Belair <[email protected]>