@@ -7,6 +7,7 @@ namespace cpu {
 
 IPEX_DEFINE_DISPATCH(mixtral_moe_tpp_kernel_stub);
 IPEX_DEFINE_DISPATCH(mixtral_moe_woq_kernel_stub);
+IPEX_DEFINE_DISPATCH(deepseek_moe_woq_kernel_stub);
 IPEX_DEFINE_DISPATCH(mixtral_moe_kernel_stub);
 
 at::Tensor mixtral_moe_tpp(
@@ -38,6 +39,41 @@ at::Tensor mixtral_moe_tpp(
       is_distributed);
 }
 
+at::Tensor deepseek_moe_tpp(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<at::Tensor>& down_wei,
+    bool tpp_fallback,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe_tpp", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_wei.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+    output = mixtral_moe_tpp_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_wei[i],
+        up_wei[i],
+        down_wei[i],
+        tpp_fallback,
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
+
 at::Tensor mixtral_moe(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -72,6 +108,87 @@ at::Tensor mixtral_moe(
       output,
       is_distributed);
 }
+
+at::Tensor deepseek_moe(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
+    const std::vector<at::Tensor>& down_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_wei.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+
+    output = mixtral_moe_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_wei[i],
+        gate_op_ctx[i]->get_data_handle(),
+        up_wei[i],
+        up_op_ctx[i]->get_data_handle(),
+        down_wei[i],
+        down_op_ctx[i]->get_data_handle(),
+        true,
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
+
+at::Tensor deepseek_moe_mkl(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
+    const std::vector<at::Tensor>& down_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe_mkl", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_wei.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+    output = mixtral_moe_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_wei[i],
+        gate_op_ctx[i]->get_data_handle(),
+        up_wei[i],
+        up_op_ctx[i]->get_data_handle(),
+        down_wei[i],
+        down_op_ctx[i]->get_data_handle(),
+        false,
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
 at::Tensor mixtral_moe_woq(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -98,6 +215,38 @@ at::Tensor mixtral_moe_woq(
       output,
       is_distributed);
 }
+at::Tensor deepseek_moe_woq(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe_woq", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_ctx.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+    output = mixtral_moe_woq_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_ctx[i]->get_data_handle(),
+        up_ctx[i]->get_data_handle(),
+        down_ctx[i]->get_data_handle(),
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
 } // namespace cpu
 } // namespace torch_ipex
 
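Each of the new deepseek_moe* entry points above walks expert_mask one expert at a time and hands the matching tokens to the existing mixtral_moe*_kernel_stub implementations. Below is a minimal PyTorch sketch of that indexing, assuming expert_mask follows the Hugging Face Mixtral convention of shape [num_experts, top_k, num_tokens]; that convention and the toy sizes are assumptions for illustration, not facts stated by this diff.

import torch

num_experts, top_k, num_tokens = 4, 2, 3
router_logits = torch.randn(num_tokens, num_experts)
routing_weights, selected_experts = torch.topk(
    torch.softmax(router_logits, dim=-1), top_k, dim=-1)
# [num_experts, top_k, num_tokens]; nonzero only where an expert was selected
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_id in range(num_experts):
    non_zero = expert_mask[expert_id].nonzero()
    if non_zero.size(0) == 0:
        continue  # expert received no tokens; the C++ loop skips it the same way
    idx = non_zero.select(1, 0)    # top-k routing slot that picked this expert
    top_x = non_zero.select(1, 1)  # token indices routed to this expert
    print(expert_id, idx.tolist(), top_x.tolist())

Under that assumption, idx picks the routing_weights column that scales the expert's output and top_x picks the rows of hidden_states to gather, which matches how the two columns of non_zero are split in the C++ loops.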
@@ -112,17 +261,53 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       "mixtral_moe_tpp",
       c10::DispatchKey::CPU,
       torch_ipex::cpu::mixtral_moe_tpp);
+  m.def(
+      "deepseek_moe_tpp(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+        Tensor[] up_wei, Tensor[] down_wei, bool tpp_fallback, Tensor routing_weights, \
+        Tensor output, bool is_distributed) -> Tensor");
+  m.impl(
+      "deepseek_moe_tpp",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moe_tpp);
   m.def(
       "mixtral_moe(Tensor hidden_states, Tensor top_x, Tensor idx, Tensor gate_wei, \
        Tensor gate_op_ctx, Tensor up_wei, Tensor up_op_ctx, Tensor down_wei, \
        Tensor down_op_ctx, bool use_dnnl, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
   m.impl("mixtral_moe", c10::DispatchKey::CPU, torch_ipex::cpu::mixtral_moe);
+  m.def(
+      "deepseek_moe(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+        __torch__.torch.classes.ipex_prepack.LinearOpContext[] gate_op_ctx, Tensor[] up_wei, \
+        __torch__.torch.classes.ipex_prepack.LinearOpContext[] up_op_ctx, Tensor[] down_wei, \
+        __torch__.torch.classes.ipex_prepack.LinearOpContext[] down_op_ctx, Tensor routing_weights, \
+        Tensor output, bool is_distributed) -> Tensor");
+  m.impl("deepseek_moe", c10::DispatchKey::CPU, torch_ipex::cpu::deepseek_moe);
+  m.def(
+      "deepseek_moe_mkl(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+        __torch__.torch.classes.ipex_prepack.MKLOpContext[] gate_op_ctx, Tensor[] up_wei, \
+        __torch__.torch.classes.ipex_prepack.MKLOpContext[] up_op_ctx, \
+        Tensor[] down_wei, __torch__.torch.classes.ipex_prepack.MKLOpContext[] down_op_ctx, \
+        Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
+  m.impl(
+      "deepseek_moe_mkl",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moe_mkl);
   m.def(
       "mixtral_moe_woq(Tensor hidden_states, Tensor top_x, Tensor idx, Tensor gate_wei, \
        Tensor up_wei, Tensor down_wei, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
   m.impl(
       "mixtral_moe_woq",
       c10::DispatchKey::CPU,
       torch_ipex::cpu::mixtral_moe_woq);
+  m.def(
+      "deepseek_moe_woq(Tensor hidden_states, Tensor expert_mask, \
+        __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] gate_ctx, \
+        __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] up_ctx, \
+        __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] down_ctx, \
+        Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
+
+  m.impl(
+      "deepseek_moe_woq",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moe_woq);
 }
 } // namespace