[WIP] Implement 2nd pass training using 1-best decoding results from the 1st pass network#198
csukuangfj wants to merge 15 commits into k2-fsa:master
Conversation
snowfall/models/second_pass_model.py
# now x2 is (B, T, F)
x_concat = torch.cat((padded_acoustics, x2), dim=-1)
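For reference, concatenating along the last dimension stacks the feature dims while leaving batch and time untouched. A minimal sketch (tensor names follow the diff; the sizes are illustrative, not the PR's actual dimensions):

```python
import torch

B, T, F1, F2 = 2, 5, 80, 128  # illustrative sizes only

padded_acoustics = torch.zeros(B, T, F1)  # acoustic features, (B, T, F1)
x2 = torch.zeros(B, T, F2)                # first-pass network output, (B, T, F2)

# Concatenate along the feature (last) dimension.
x_concat = torch.cat((padded_acoustics, x2), dim=-1)
assert x_concat.shape == (B, T, F1 + F2)
```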
TODO(fangjun): Use cross attention here
- query: x2
- key and value: padded_acoustics
and masked self-attention
- key, query, and value: x2
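The TODO above could be sketched with PyTorch's `nn.MultiheadAttention`. This is a hedged sketch of the idea, not the PR's implementation; the embedding size, head count, and variable names are assumptions:

```python
import torch
import torch.nn as nn

B, T, E = 2, 5, 16  # batch, time, embedding dim (illustrative)

x2 = torch.randn(B, T, E)                # second-pass features (query)
padded_acoustics = torch.randn(B, T, E)  # first-pass acoustics (key/value)

# Cross attention: query from x2, key and value from padded_acoustics.
cross_attn = nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)
out, _ = cross_attn(query=x2, key=padded_acoustics, value=padded_acoustics)

# Masked self-attention over x2: a causal mask (True = blocked) prevents
# position i from attending to positions > i.
self_attn = nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
out2, _ = self_attn(x2, x2, x2, attn_mask=causal_mask)

assert out.shape == (B, T, E) and out2.shape == (B, T, E)
```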
@@ -0,0 +1,484 @@
#!/usr/bin/env python3
common2.py is the same as common.py, except that it adds some code supporting the second-pass model. To avoid conflicts with master, a new file is used. The same goes for the other xxx2.py files, e.g., lm_rescore2.py and mmi2.py.
import k2

class Nbest(object):
This file implements the Nbest class proposed in
#232 (comment)
Please review whether it matches the proposal.
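The proposal itself is not quoted in this thread. In k2, such a class would presumably pair an FsaVec of n-best paths with a ragged shape mapping paths back to utterances; below is a dependency-free Python sketch of that idea, where plain lists stand in for the k2 structures. The names and layout are assumptions, not the PR's actual code:

```python
class Nbest(object):
    """Holds n-best hypotheses for a batch of utterances.

    Sketch only: the real class would wrap a k2.Fsa (one FSA per path)
    plus a k2.RaggedShape; here a flat list of paths and a row_splits
    list stand in for both.
    """

    def __init__(self, paths, row_splits):
        # paths[i] is the i-th hypothesis (e.g., a list of word IDs).
        # paths[row_splits[u] : row_splits[u+1]] belong to utterance u.
        assert row_splits[0] == 0 and row_splits[-1] == len(paths)
        self.paths = paths
        self.row_splits = row_splits

    @property
    def num_utterances(self):
        return len(self.row_splits) - 1

    def paths_of(self, utt):
        """Return the hypotheses belonging to utterance `utt`."""
        start, end = self.row_splits[utt], self.row_splits[utt + 1]
        return self.paths[start:end]
```

For example, `Nbest([[1, 2], [1, 3], [4]], [0, 2, 3])` describes two utterances: the first has two hypotheses, the second has one.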
That's great! Yes, it looks like what I had in mind.
I assume you would separate it from this PR, though? Or maybe even submit it to k2, since there's a lot going on here.
Will move it to k2.
It implements #106 (comment)
The training objective is decreasing and seems to be converging. I will post the decoding results later.