Skip to content

Commit b703072

Browse files
committed
Roll back to a single RNN in the Speller
1 parent 8959ab1 commit b703072

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

asr/models/las/network.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ def forward(self, s, h, len_mask=None):
180180
class Speller(nn.Module):
181181

182182
def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None, eos=None,
183-
rnn_type=nn.LSTM, rnn_hidden_size=512, rnn_num_layers=[1, 1],
183+
rnn_type=nn.LSTM, rnn_hidden_size=512, rnn_num_layers=2,
184184
apply_attend_proj=False, proj_hidden_size=256, num_attend_heads=1,
185185
masked_attend=True):
186186
super().__init__()
@@ -196,11 +196,9 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
196196
Hs, Hc, Hy = rnn_hidden_size, listen_vec_size, label_vec_size
197197

198198
self.rnn_num_layers = rnn_num_layers
199-
self.rnns1 = rnn_type(input_size=(Hy + Hc), hidden_size=Hs, num_layers=rnn_num_layers[0],
200-
bias=True, bidirectional=False, batch_first=True)
201-
self.norm = nn.LayerNorm(Hs)
202-
self.rnns2 = rnn_type(input_size=(Hs + Hc), hidden_size=Hs, num_layers=rnn_num_layers[1],
199+
self.rnns = rnn_type(input_size=(Hy + Hc), hidden_size=Hs, num_layers=rnn_num_layers,
203200
bias=True, bidirectional=False, batch_first=True)
201+
self.norm = nn.LayerNorm(Hs, elementwise_affine=False)
204202

205203
self.attention = Attention(state_vec_size=Hs, listen_vec_size=Hc,
206204
apply_proj=apply_attend_proj, proj_hidden_size=proj_hidden_size,
@@ -209,7 +207,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
209207
self.masked_attend = masked_attend
210208

211209
self.chardist = nn.Sequential(OrderedDict([
212-
('fc1', nn.Linear(Hs, 128, bias=True)),
210+
('fc1', nn.Linear(Hs + Hc, 128, bias=True)),
213211
('fc2', nn.Linear(128, label_vec_size, bias=False)),
214212
]))
215213

@@ -227,7 +225,7 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
227225
sos = int2onehot(h.new_full((batch_size, 1), self.sos), num_classes=self.label_vec_size).float()
228226
eos = int2onehot(h.new_full((batch_size, 1), self.eos), num_classes=self.label_vec_size).float()
229227

230-
hidden1, hidden2 = None, None
228+
hidden = None
231229
y_hats = list()
232230
attentions = list()
233231

@@ -237,11 +235,10 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
237235
y_hats_seq_lens = torch.ones((batch_size, ), dtype=torch.int) * self.max_seq_lens
238236

239237
for t in range(self.max_seq_lens):
240-
s, hidden1 = self.rnns1(x, hidden1)
238+
s, hidden = self.rnns(x, hidden)
241239
s = self.norm(s)
242240
c, a = self.attention(s, h, in_mask)
243-
s, hidden2 = self.rnns2(torch.cat([s, c], dim=-1), hidden2)
244-
y_hat = self.chardist(s)
241+
y_hat = self.chardist(torch.cat([s, c], dim=-1))
245242
y_hat = self.softmax(y_hat)
246243

247244
y_hats.append(y_hat)
@@ -337,7 +334,7 @@ def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
337334

338335
self.spell = Speller(listen_vec_size=listen_vec_size, label_vec_size=self.label_vec_size,
339336
sos=self.sos, eos=self.eos, max_seq_lens=256,
340-
rnn_hidden_size=state_vec_size, rnn_num_layers=[1, 1],
337+
rnn_hidden_size=state_vec_size, rnn_num_layers=2,
341338
apply_attend_proj=True, proj_hidden_size=128, num_attend_heads=num_attend_heads)
342339

343340
self.attentions = None

0 commit comments

Comments
 (0)