@@ -180,7 +180,7 @@ def forward(self, s, h, len_mask=None):
180180class Speller (nn .Module ):
181181
182182 def __init__ (self , listen_vec_size , label_vec_size , max_seq_lens = 256 , sos = None , eos = None ,
183- rnn_type = nn .LSTM , rnn_hidden_size = 512 , rnn_num_layers = [ 1 , 1 ] ,
183+ rnn_type = nn .LSTM , rnn_hidden_size = 512 , rnn_num_layers = 2 ,
184184 apply_attend_proj = False , proj_hidden_size = 256 , num_attend_heads = 1 ,
185185 masked_attend = True ):
186186 super ().__init__ ()
@@ -196,11 +196,9 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
196196 Hs , Hc , Hy = rnn_hidden_size , listen_vec_size , label_vec_size
197197
198198 self .rnn_num_layers = rnn_num_layers
199- self .rnns1 = rnn_type (input_size = (Hy + Hc ), hidden_size = Hs , num_layers = rnn_num_layers [0 ],
200- bias = True , bidirectional = False , batch_first = True )
201- self .norm = nn .LayerNorm (Hs )
202- self .rnns2 = rnn_type (input_size = (Hs + Hc ), hidden_size = Hs , num_layers = rnn_num_layers [1 ],
199+ self .rnns = rnn_type (input_size = (Hy + Hc ), hidden_size = Hs , num_layers = rnn_num_layers ,
203200 bias = True , bidirectional = False , batch_first = True )
201+ self .norm = nn .LayerNorm (Hs , elementwise_affine = False )
204202
205203 self .attention = Attention (state_vec_size = Hs , listen_vec_size = Hc ,
206204 apply_proj = apply_attend_proj , proj_hidden_size = proj_hidden_size ,
@@ -209,7 +207,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
209207 self .masked_attend = masked_attend
210208
211209 self .chardist = nn .Sequential (OrderedDict ([
212- ('fc1' , nn .Linear (Hs , 128 , bias = True )),
210+ ('fc1' , nn .Linear (Hs + Hc , 128 , bias = True )),
213211 ('fc2' , nn .Linear (128 , label_vec_size , bias = False )),
214212 ]))
215213
@@ -227,7 +225,7 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
227225 sos = int2onehot (h .new_full ((batch_size , 1 ), self .sos ), num_classes = self .label_vec_size ).float ()
228226 eos = int2onehot (h .new_full ((batch_size , 1 ), self .eos ), num_classes = self .label_vec_size ).float ()
229227
230- hidden1 , hidden2 = None , None
228+ hidden = None
231229 y_hats = list ()
232230 attentions = list ()
233231
@@ -237,11 +235,10 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
237235 y_hats_seq_lens = torch .ones ((batch_size , ), dtype = torch .int ) * self .max_seq_lens
238236
239237 for t in range (self .max_seq_lens ):
240- s , hidden1 = self .rnns1 (x , hidden1 )
238+ s , hidden = self .rnns (x , hidden )
241239 s = self .norm (s )
242240 c , a = self .attention (s , h , in_mask )
243- s , hidden2 = self .rnns2 (torch .cat ([s , c ], dim = - 1 ), hidden2 )
244- y_hat = self .chardist (s )
241+ y_hat = self .chardist (torch .cat ([s , c ], dim = - 1 ))
245242 y_hat = self .softmax (y_hat )
246243
247244 y_hats .append (y_hat )
@@ -337,7 +334,7 @@ def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
337334
338335 self .spell = Speller (listen_vec_size = listen_vec_size , label_vec_size = self .label_vec_size ,
339336 sos = self .sos , eos = self .eos , max_seq_lens = 256 ,
340- rnn_hidden_size = state_vec_size , rnn_num_layers = [ 1 , 1 ] ,
337+ rnn_hidden_size = state_vec_size , rnn_num_layers = 2 ,
341338 apply_attend_proj = True , proj_hidden_size = 128 , num_attend_heads = num_attend_heads )
342339
343340 self .attentions = None
0 commit comments