Add group_norm_v2 #1887
Hi @crcrpar, could you also help review this PR?
constexpr auto params = compute_gn_params<T, false, false, HW, G, CPG, LB_N, RUNTIME_CUDA_ARCH, LB_SM_COUNT, EFFECTIVE_CUDA_ARCH, SM_MARGIN>();
constexpr int BLOCK_DIM_X = std::get<0>(params);
constexpr int C_PER_BLOCK = std::get<1>(params);
constexpr int ROWS_PER_BLOCK = std::get<2>(params);
constexpr int VEC_ELEMS = std::get<3>(params);
constexpr bool LOAD_TWICE = std::get<4>(params);
constexpr int BLOCKS_PER_SM = std::get<5>(params);
constexpr bool HARDWARE_CLUSTER = std::get<6>(params);
(Ah, structured bindings cannot be used with constexpr...)
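To illustrate the remark above: in C++17 and C++20 a structured binding declaration may not carry the `constexpr` specifier, which is why each tuple element is pulled out with `std::get`. The sketch below uses a hypothetical stand-in, `compute_params_tuple`, rather than the PR's `compute_gn_params`, and also shows one possible alternative (a struct with named members) that is not taken from the PR.

```cpp
#include <tuple>

// Hypothetical stand-in for the PR's compute_gn_params; the values are made up.
constexpr std::tuple<int, int, bool> compute_params_tuple() {
    return std::make_tuple(128, 8, false);
}

// Not allowed in C++17/C++20 (structured bindings cannot be constexpr):
//   constexpr auto [BLOCK_DIM_X, C_PER_BLOCK, LOAD_TWICE] = compute_params_tuple();

// Pattern used in the diff: one constexpr tuple, then std::get per element.
constexpr auto params = compute_params_tuple();
constexpr int BLOCK_DIM_X = std::get<0>(params);
constexpr int C_PER_BLOCK = std::get<1>(params);
constexpr bool LOAD_TWICE = std::get<2>(params);

// Possible alternative (an assumption, not from the PR): named members
// avoid positional indices while staying usable in constant expressions.
struct NamedParams {
    int block_dim_x;
    int c_per_block;
    bool load_twice;
};

constexpr NamedParams compute_params_struct() {
    return NamedParams{128, 8, false};
}

constexpr NamedParams kNamed = compute_params_struct();

static_assert(BLOCK_DIM_X == 128, "tuple element 0 is a constant expression");
static_assert(kNamed.c_per_block == 8, "struct member is a constant expression");
```

Either form yields values that can be used as template arguments; the struct variant trades a small type definition for names at the call site.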
crcrpar
left a comment
Could you add some test cases? For example:
- check the numerics and functionality of v2
- check that an error is thrown as expected for invalid inputs
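As a rough idea of what such checks could compare against, here is a minimal CPU reference for group normalization, written as a sketch under stated assumptions: a channels-last `(N, HW, C)` float buffer, the standard GroupNorm formula, and a hypothetical function name `group_norm_ref`. It is not the PR's test code and says nothing about how the v2 kernel validates its inputs.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reference: GroupNorm over a channels-last (N, HW, C) buffer.
// y = (x - mean) / sqrt(var + eps) * gamma[c] + beta[c], with mean and var
// computed per sample and per group of C / G channels.
std::vector<float> group_norm_ref(const std::vector<float>& x,
                                  const std::vector<float>& gamma,
                                  const std::vector<float>& beta,
                                  int N, int HW, int C, int G,
                                  float eps = 1e-5f) {
    if (N <= 0 || HW <= 0 || C <= 0 || G <= 0 || C % G != 0) {
        throw std::invalid_argument("invalid shape: C must be divisible by G");
    }
    const std::size_t total = static_cast<std::size_t>(N) * HW * C;
    if (x.size() != total || gamma.size() != static_cast<std::size_t>(C) ||
        beta.size() != static_cast<std::size_t>(C)) {
        throw std::invalid_argument("buffer size does not match shape");
    }
    const int cpg = C / G;  // channels per group
    std::vector<float> y(total);
    for (int n = 0; n < N; ++n) {
        for (int g = 0; g < G; ++g) {
            // Accumulate in double to keep the reference itself accurate.
            double sum = 0.0, sq = 0.0;
            for (int hw = 0; hw < HW; ++hw) {
                for (int c = 0; c < cpg; ++c) {
                    const std::size_t i =
                        (static_cast<std::size_t>(n) * HW + hw) * C + g * cpg + c;
                    sum += x[i];
                    sq += static_cast<double>(x[i]) * x[i];
                }
            }
            const double count = static_cast<double>(HW) * cpg;
            const double mean = sum / count;
            const double var = std::max(sq / count - mean * mean, 0.0);
            const double rstd = 1.0 / std::sqrt(var + eps);
            for (int hw = 0; hw < HW; ++hw) {
                for (int c = 0; c < cpg; ++c) {
                    const int ch = g * cpg + c;
                    const std::size_t i =
                        (static_cast<std::size_t>(n) * HW + hw) * C + ch;
                    y[i] = static_cast<float>((x[i] - mean) * rstd * gamma[ch] + beta[ch]);
                }
            }
        }
    }
    return y;
}
```

A numeric test would compare the kernel output against such a reference within a tolerance suited to the dtype, and an invalid-input test would assert that a shape like `C = 2, G = 3` is rejected.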
if bare_metal_version >= Version("12.8"):
    arch_flags = ["-gencode=arch=compute_100,code=sm_100"]
else:
    arch_flags = ["-gencode=arch=compute_90,code=compute_90"]
QQ: Looking at these conditions, this appears to be Blackwell-exclusive. Is it supposed to work on Hopper?
It works on Hopper if we add it to arch_flags and DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT. However, I think we need a smarter way to determine the template args to support more GPUs, rather than exhaustively enumerating SM_COUNT values.
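For readers unfamiliar with the pattern being discussed, the sketch below shows one generic way to map a runtime SM count onto a compile-time lower bound and dispatch to a matching template instantiation. It is an illustration only: the table `kLowerBounds`, the functions `pick_lower_bound`, `launch`, and `dispatch`, and the listed values are all hypothetical, and nothing here describes how `DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT` is actually implemented.

```cpp
// Hypothetical table of supported SM-count lower bounds, in descending order.
// The values are illustrative, not taken from the PR.
constexpr int kLowerBounds[] = {148, 132, 108};

// Return the largest supported lower bound that does not exceed the runtime
// SM count, or -1 if the device has fewer SMs than every entry.
constexpr int pick_lower_bound(int runtime_sm_count) {
    for (int lb : kLowerBounds) {
        if (runtime_sm_count >= lb) {
            return lb;
        }
    }
    return -1;
}

// Stand-in for a kernel launcher templated on the lower bound. A real
// launcher would derive its launch configuration from LB_SM_COUNT.
template <int LB_SM_COUNT>
int launch() {
    return LB_SM_COUNT;
}

// Runtime-to-compile-time dispatch: each case instantiates one template.
inline int dispatch(int runtime_sm_count) {
    switch (pick_lower_bound(runtime_sm_count)) {
        case 148: return launch<148>();
        case 132: return launch<132>();
        case 108: return launch<108>();
        default:  return -1;  // unsupported device in this sketch
    }
}
```

The cost of this approach is visible in the `switch`: every supported bound needs its own instantiation, which is the enumeration the comment above would like to avoid.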
Have the "check numeric and functionality of v2" tests been performed?
Yes, tests were added and passed.
@nWEIdia yes.
GroupNorm v2 uses coalesced memory access and outperforms the original GroupNorm extension on a set of shapes tested on B200.