Hey @anko-intel, thanks for submitting the PR.
CI supported jobs: [clang, unix-cpu, unix-gpu, centos-gpu, windows-cpu, windows-gpu, miscellaneous, website, sanity, edge, centos-cpu]
OneDNN doesn't support the float16 format, so a fallback to the standard implementation is needed. This fixes issue 19631.
4d22ab7 to aed0619
@rongzha1 - could you review?
src/operator/tensor/amp_cast.cc
Outdated
    mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
    for (size_t i = 0; i < i_ndim; i++) {
      i_dims[i] = static_cast<int>(data.shape()[i]);
    if (data.dtype() != mshadow::kFloat16) {
Shall we add isValidMKLDNNDataType() to check whether the data type is supported by MKLDNN? mshadow has many data types, and some of them are not supported: https://github.com/apache/incubator-mxnet/blob/64f737cdd59fe88d2c5b479f25d011c5156b6a8a/3rdparty/mshadow/mshadow/base.h#L364:3
I considered that. An isValidMKLDNNDataType() function would make sense if it could be used in many places, such as MKLDNNStorageType() for FInferStorageType. But in this particular case the amp_cast operator only accepts 3 float types (see https://github.com/apache/incubator-mxnet/blob/v1.x/src/operator/tensor/amp_cast.h#L70 ), so I just excluded float16 as not supported in MKLDNN.
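For reference, the helper discussed in this thread could look roughly like the sketch below. This is not code from the PR: `isValidMKLDNNDataType()` is the hypothetical function proposed above, the `TypeFlag` enum is a local stand-in for mshadow's type flags (names follow mshadow, numeric values are illustrative only), and the set of accepted types is an assumption made for the example.

```cpp
// Local stand-in for mshadow's type flags so the sketch compiles on its own.
// Names follow mshadow; the numeric values here are illustrative only.
enum TypeFlag {
  kFloat32,
  kFloat64,
  kFloat16,
  kUint8,
  kInt32,
  kInt8,
  kInt64,
  kBfloat16
};

// Hypothetical helper: returns true when the MKLDNN path may handle `dtype`.
// The accepted set below is an assumption for this sketch; float16 is
// rejected because, per this PR, OneDNN does not support it.
inline bool isValidMKLDNNDataType(int dtype) {
  switch (dtype) {
    case kFloat32:
    case kBfloat16:
    case kInt32:
    case kInt8:
    case kUint8:
      return true;
    default:
      return false;
  }
}
```

With such a helper, the check in amp_cast.cc could read `if (isValidMKLDNNDataType(data.dtype()))` instead of comparing against `mshadow::kFloat16` directly. The trade-off is the one described above: the helper only pays off if several operators share it.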
@PatricZhao, @szha could you review and merge if everything is OK?
szha left a comment
Thanks for the fix! Could you add a test for verification?
@mxnet-bot run ci [centos-cpu, unix-gpu]
Jenkins CI successfully triggered: [centos-cpu, unix-gpu]
* Fix AmpCast for float16
  OneDNN doesn't support the float16 format, so a fallback to the standard implementation is needed. It fixes issue 19631.
* Enable amp_cast test for float16 on CPU context
Description
OneDNN doesn't support the float16 format, so a fallback to the standard implementation is needed.
It fixes issue #19631.
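The fallback described here can be pictured with the minimal sketch below. It is an illustration under stated assumptions, not the code merged in this PR: `DispatchAmpCast`, `RunOneDNNCast` and `RunFallbackCast` are hypothetical names, the enum is a local stand-in for the mshadow type flags, and checking the output type in addition to the input type is an assumption of the sketch.

```cpp
#include <string>

// Local stand-in for the float types amp_cast accepts (illustrative values).
enum TypeFlag { kFloat32, kFloat16, kBfloat16 };

// Hypothetical stand-ins for the two execution paths; each returns a label
// so the chosen path can be observed.
inline std::string RunOneDNNCast() { return "onednn"; }
inline std::string RunFallbackCast() { return "fallback"; }

// Take the OneDNN path only when neither side of the cast is float16;
// otherwise fall back to the standard implementation.
// Checking the output dtype as well is an assumption of this sketch.
inline std::string DispatchAmpCast(TypeFlag in_dtype, TypeFlag out_dtype) {
  if (in_dtype != kFloat16 && out_dtype != kFloat16) {
    return RunOneDNNCast();
  }
  return RunFallbackCast();
}
```

For example, `DispatchAmpCast(kFloat32, kBfloat16)` selects the OneDNN path, while any cast that involves `kFloat16` goes through the fallback.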