-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathds2_ctc_loss_layer.py
41 lines (33 loc) · 1.29 KB
/
ds2_ctc_loss_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# --------------------------------------------------------
# Deep Speech 2 Caffe Implementation
# Written by Tian, Feng <[email protected]>
# --------------------------------------------------------
"""The CTC loss layer used during training of a DS2 network.

DS2CtcLossLayer implements a Caffe Python layer.
"""
import caffe
import numpy as np
import ctc
class DS2CtcLossLayer(caffe.Layer):
    """Caffe Python layer that computes the DeepSpeech2 CTC training loss.

    Bottom blobs (exactly four, in this order):
        0: activations (``acts``) handed to the CTC loss
        1: targets
        2: input percentages; multiplied by ``acts.shape[0]`` to obtain the
           input sizes handed to the CTC loss
        3: target sizes

    Top blob:
        0: one element holding the summed CTC cost divided by
           ``LOSS_NORMALIZER``.

    The exact shapes and dtypes expected for the bottoms are dictated by the
    external ``ctc`` module, which is not visible from this file.
    """

    # Divisor applied to the summed cost in forward().
    # NOTE(review): presumably the training batch size -- confirm against the
    # net definition before changing it. Only the reported loss is scaled by
    # this value; the gradients written in backward() are not.
    LOSS_NORMALIZER = 20

    def setup(self, bottom, top):
        """Validate the bottom count, size the loss blob, create the CTC op."""
        assert len(bottom) == 4, (
            "DS2CtcLossLayer expects 4 bottom blobs "
            "(acts, targets, input_percentages, target_sizes), got %d"
            % len(bottom))
        # The loss is reported as a single value.
        top[0].reshape(1)
        self.ctcloss = ctc.CTCLoss()

    def forward(self, bottom, top):
        """Run the CTC loss and store the normalized total cost in top[0]."""
        acts = bottom[0].data
        targets = bottom[1].data
        input_percentages = bottom[2].data
        target_sizes = bottom[3].data

        # Convert the fractional lengths into absolute lengths along axis 0 of
        # acts (presumably the time axis -- TODO confirm). The integer cast
        # truncates toward zero.
        # The builtin ``int`` is used because ``np.int`` was merely an alias
        # for it and has been removed from NumPy (1.24+).
        sizes = (input_percentages * acts.shape[0]).astype(int, copy=False)

        self.ctcloss.ctc_loss(acts, targets, sizes, target_sizes)
        top[0].data[0] = self.ctcloss.costs.sum() / self.LOSS_NORMALIZER

    def backward(self, top, propagate_down, bottom):
        """Write the CTC gradients into the diff of the activations blob.

        Only bottom[0] receives a gradient; the other three bottoms are left
        untouched. The gradients are read from ``self.ctcloss.grads``, which
        is presumably filled in by the ``ctc_loss`` call made in forward().
        """
        if not propagate_down[0]:
            # Caffe does not need a gradient for the activations.
            return
        # NOTE(review): the gradient is neither multiplied by top[0].diff (the
        # loss weight) nor divided by LOSS_NORMALIZER, so its scale differs
        # from the loss reported in forward(). Left as-is to avoid changing
        # training behaviour -- confirm whether this is intended.
        bottom[0].diff[...] = self.ctcloss.grads.astype(np.float32, copy=False)

    def reshape(self, bottom, top):
        """Do nothing: top[0] gets its fixed one-element shape in setup()."""