[v1.x] ONNX Support for Transformer #20048
Conversation
Hey @Zha0q1, thanks for submitting the PR.
CI supported jobs: [centos-gpu, unix-gpu, miscellaneous, edge, website, windows-cpu, unix-cpu, windows-gpu, sanity, centos-cpu, clang]
assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
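For context, a minimal sketch of how the two predictions being compared here could be produced, assuming onnxruntime for the ONNX side; `model`, `src`, and `onnx_file` are illustrative names, not necessarily the test's exact variables:

```python
import onnxruntime as rt
from mxnet.test_utils import assert_almost_equal

pred = model(src).asnumpy()                 # native MXNet prediction
sess = rt.InferenceSession(onnx_file)       # exported graph
input_name = sess.get_inputs()[0].name
pred_onx = sess.run(None, {input_name: src.asnumpy()})
# the exporter introduces small numeric drift, hence the loose tolerances
assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
```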
def verify_one_step_ahead_decoder():
@sxjscience, would you help take a quick look at this function? Thanks!
batch = 7
seq_length = 16
C_in = 512
What are C_in and C_out? Should we also test when C_in != C_out?
You can refer to this file https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py for C_in and C_out. Those are defined in the pretrained model, so we need to set them to the same values used there.
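A sketch of where those dimensions come from, assuming the GluonNLP v0.10.x `get_model` API (the call signature and dataset name are my assumptions, not code from this PR):

```python
import gluonnlp as nlp

# loads the checkpoint whose hyperparameters fix C_in/C_out
model, src_vocab, tgt_vocab = nlp.model.get_model(
    'transformer_en_de_512',
    dataset_name='WMT2014',
    pretrained=True)
# the 512-unit dims are baked into this checkpoint, which is why the
# test hard-codes C_in = C_out = 512 rather than varying them
```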
prefix = "%s/one_step_ahead_decoder" % tmp_path

# the input data order
perm = [2, 0, 1]
Could we put the correct order when instantiating the list instead of using perm?
I used a perm list so that the actual in_shapes and in_types lists can keep the same order as the inputs passed to the native model. It's just that the converted ONNX model takes them in a different order somehow. I think this is more consistent; what do you think?
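A small sketch of the perm idea (the shapes are illustrative, not the PR's exact inputs): the lists stay in the native model's input order, and `perm` maps them to the order the converted ONNX graph expects.

```python
batch, seq_length, C_in = 7, 16, 512

in_shapes = [(batch, seq_length, C_in),  # input 0, native order
             (batch, seq_length),        # input 1
             (batch, C_in)]              # input 2
in_types = ['float32', 'float32', 'float32']

perm = [2, 0, 1]  # ONNX order: input 2 first, then 0, then 1
onnx_in_shapes = [in_shapes[i] for i in perm]
onnx_in_types = [in_types[i] for i in perm]
```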
seq_length = 16
C_in = 512
C_out = 512
src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
Curious, doesn't src need to be an int type?
No, it's float in the original MXNet model too. This shouldn't matter, I think, because the operator will apply ceiling/flooring.
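A quick illustration of that point (a standalone check, not code from the PR): MXNet's Embedding op accepts float indices and truncates them to integers, so a float32 `src` behaves like its integer-valued counterpart.

```python
import mxnet as mx

weight = mx.nd.random.uniform(shape=(36794, 512))  # vocab x emb dim
ids_float = mx.nd.array([[3.0, 7.0], [100.0, 42.0]], dtype='float32')
emb = mx.nd.Embedding(data=ids_float, weight=weight,
                      input_dim=36794, output_dim=512)
print(emb.shape)  # (2, 2, 512), same as with integer-typed ids
```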
LGTM, thanks
This PR adds support for the pretrained transformer_en_de_512 model. We break the transformer into encoder, decoder, embedding, and projection parts and test each part separately. To get one_step_ahead_decoder to work, the seq_len axis is dynamic.
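A hedged sketch of what a dynamic seq_len export could look like, assuming the MXNet v1.x `mx.onnx.export_model` API with its `dynamic`/`dynamic_input_shapes` options; the file names and shapes are illustrative:

```python
import mxnet as mx

batch, seq_length, C_in = 7, 16, 512
in_shapes = [(batch, seq_length, C_in)]
in_types = ['float32']
# None marks seq_len as dynamic, so the exported decoder accepts any length
dynamic_input_shapes = [(batch, None, C_in)]

mx.onnx.export_model('decoder-symbol.json', 'decoder-0000.params',
                     in_shapes, in_types, 'decoder.onnx',
                     dynamic=True,
                     dynamic_input_shapes=dynamic_input_shapes)
```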