[MXNET-798] Fix the dtype cast from non float32 in Gradient computation #12290
anirudh2290 merged 14 commits into apache:master
Conversation
@eric-haibin-lin @piiswrong @haojin2 I would appreciate your review.
if __name__ == "__main__":
    test_infer_multiout_op()
I think this should go to something like test_operator.py instead of creating a separate file for it? And, please see https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L7017-L7018 for how to use nosetests.
This is not to test the functionality of the operator but a general type casting issue for all multi-output operators. I am inclined to add it to the infer type tests but would like to hear more suggestions.
Changed the test to run nose runmodule.
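For readers unfamiliar with that convention, the pattern the linked test_operator.py lines use looks roughly like this (a sketch of the standard MXNet test-file footer, not a quote of the exact file):

```python
# Standard footer for MXNet unit test modules: when the file is executed
# directly, nose discovers and runs every test_* function in this module.
if __name__ == '__main__':
    import nose
    nose.runmodule()
```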
Changed to [WIP] to fix some platform-dependent unit test failures.
  }
-  auto finfer = finfer_shape.get(inode.source->op(), fdefault);
+  auto finfer = (inode.source->op() == Op::Get("_zeros")) ? fdefault :
+                finfer_shape.get(inode.source->op(), fdefault);
Are you sure about this? This affects all _zeros ops, not just the case you mentioned.
You are right, this is breaking some unit tests (however, because the unit tests on the master branch are broken on macOS, I wasn't able to verify before check-in). I have changed the PR to WIP.
@eric-haibin-lin Please review this new implementation. Thanks for your suggestion!
What's up with the build?
@eric-haibin-lin Not sure exactly. An earlier build passed (dcc5f78). After I renamed some variables, the build on ARM7 failed. I can submit an empty change to trigger the build again.
with autograd.record():
    test64 = test_func(data64)
test64.backward()
assert_almost_equal(data64.grad.asnumpy().all(), data32.grad.asnumpy().all())
Can you set rtol and atol to a bigger value than the default here?
Why increase rtol and atol if the unit test can pass with the defaults?
This can be flaky. You are comparing a float32 numpy array to a float64 numpy array, and the atol and rtol defaults are small.
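For context, a minimal sketch of the kind of mixed-dtype gradient test being discussed (illustrative only, not the PR's actual test; topk is used here just as an example multi-output operator, and test_func, data32, and data64 are assumed names):

```python
import mxnet as mx
from mxnet import autograd
from mxnet.test_utils import assert_almost_equal

def test_func(data):
    # topk with ret_typ='both' has two outputs; only the values output is used,
    # so the gradient pass has to synthesize a zeros gradient for the indices
    # output -- the spot where the float32 default used to leak in.
    val, _ = mx.nd.topk(data, k=3, ret_typ='both')
    return val.sum()

data32 = mx.nd.random.uniform(shape=(10,), dtype='float32')
data64 = data32.astype('float64')
data32.attach_grad()
data64.attach_grad()

with autograd.record():
    out32 = test_func(data32)
out32.backward()

with autograd.record():
    out64 = test_func(data64)
out64.backward()

# Larger-than-default tolerances, since a float32 result is compared
# against a float64 one.
assert_almost_equal(data64.grad.asnumpy(), data32.grad.asnumpy(),
                    rtol=1e-3, atol=1e-5)
```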
Also, maybe we should add zeros to the list of APIs that may be good to break for 2.0 (#9686).
@anirudh2290 The _zeros_without_dtype operator is a private operator used only in building the nnvm graph. It is not meant to be exposed to users.
@apeforest What I meant is that we can change the dtype default to -1 for the zeros operator in 2.0.
@anirudh2290 Thanks for the clarification. I have increased the atol and rtol values in the unit test. As for changing the dtype default to -1 for zeros, I think it is not related to this PR and may cause a backward compatibility issue with old models, so I would prefer doing that in a separate PR. Please let me know what you think. Thanks.
I'm not suggesting doing it in this PR. I just wanted to document it in the list of APIs to break for 2.0, and we can do it before the 2.0 release.
…on (apache#12290)
* Fix the dtype mismatch in derived _zeros node
* Add unittest for infer dtype
* Add one more unit test
* Add nose runmodule
* Add a zero operator with no default dtype
* Rename variables
* fix a bug: rename operator for gpu
* Increase atol and rtol to avoid flakiness
Description
This PR fixes issues #9067 and #8799, where gradient computation for operators with multiple outputs fails in ndarray if the dtype is not float32.
The root cause is that a _zeros operator was added for the other, don't-care output. The _zeros operator uses the float32 dtype by default, which causes a conflict if the dtype of the ndarray is not float32. My solution is to create a new _zeros_without_dtype operator that has no default dtype and to use it in place of the _zeros operator in the computation graph. This change resolves the dtype conflict and should be backward compatible.
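As a small front-end illustration of that default (not part of the patch itself), the plain zeros op silently falls back to float32 when no dtype is given, which is exactly what the derived gradient node used to inherit:

```python
import mxnet as mx

z_default = mx.nd.zeros((2, 3))                   # no dtype given -> float32
z_float64 = mx.nd.zeros((2, 3), dtype='float64')  # explicit dtype is honored

print(z_default.dtype)   # <class 'numpy.float32'>
print(z_float64.dtype)   # <class 'numpy.float64'>
```

The _zeros_without_dtype variant avoids baking in that default, so type inference can propagate the surrounding graph's dtype instead.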
A unit test is added to test this fix.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments