File tree Expand file tree Collapse file tree 3 files changed +26
-14
lines changed
intel_extension_for_pytorch Expand file tree Collapse file tree 3 files changed +26
-14
lines changed Original file line number Diff line number Diff line change @@ -101,18 +101,22 @@ def may_import_deepspeed_modules():
101101 try :
102102 # import deepspeed in a global space will raise circular import error
103103 # intel-extension-for-deepspeed imports both IPEX and deepspeed
104- from deepspeed .module_inject .layers import LinearAllreduce , LinearLayer
105-
106- ds_layers = [ LinearAllreduce , LinearLayer ]
107-
108- # TODO: remove this logic once deepspeed LmHeadLinearAllreduce change has been upstream-ed.
109- try :
110- from deepspeed . module_inject . layers import LmHeadLinearAllreduce
104+ from deepspeed .module_inject .layers import (
105+ LinearAllreduce ,
106+ LinearLayer ,
107+ LmHeadLinearAllreduce ,
108+ fused_LinearLayer ,
109+ GateUpPack_LinearLayer ,
110+ )
111111
112- ds_layers .append (LmHeadLinearAllreduce )
113- return ds_layers
114- except ImportError :
115- return ds_layers
112+ ds_layers = [
113+ LinearAllreduce ,
114+ LinearLayer ,
115+ LmHeadLinearAllreduce ,
116+ fused_LinearLayer ,
117+ GateUpPack_LinearLayer ,
118+ ]
119+ return ds_layers
116120 except ImportError :
117121 return None
118122
Original file line number Diff line number Diff line change @@ -292,11 +292,19 @@ def _convert_woq_with_low_precision_checkpoint(
292292
293293 deepspeed_modules = may_import_deepspeed_modules ()
294294 if deepspeed_modules is not None :
295- LinearAllreduce , LinearLayer , LmHeadLinearAllreduce = deepspeed_modules [:]
295+ (
296+ LinearAllreduce ,
297+ LinearLayer ,
298+ LmHeadLinearAllreduce ,
299+ fused_LinearLayer ,
300+ GateUpPack_LinearLayer ,
301+ ) = deepspeed_modules
296302 q_op_map .update (
297303 {
298304 LinearAllreduce : IpexWoqLinearAllreduce ,
299305 LinearLayer : WeightOnlyQuantizedLinear ,
306+ fused_LinearLayer : WeightOnlyQuantizedLinear ,
307+ GateUpPack_LinearLayer : WeightOnlyQuantizedLinear ,
300308 }
301309 )
302310
Original file line number Diff line number Diff line change @@ -197,7 +197,7 @@ def _get_ds_model(self, m_linear):
197197 def test_ipex_optimize (self ):
198198 deepspeed_modules = may_import_deepspeed_modules ()
199199 if deepspeed_modules is not None :
200- LinearAllreduce , LinearLayer , LmHeadLinearAllreduce = deepspeed_modules
200+ LinearAllreduce , LinearLayer , LmHeadLinearAllreduce = deepspeed_modules [: 3 ]
201201
202202 x = torch .randn (2 , 3 , 64 )
203203 m_linear = DeepSpeedTestM (MyLmHeadModel ).eval ()
@@ -241,7 +241,7 @@ def _test_quantization(
241241 ):
242242 deepspeed_modules = may_import_deepspeed_modules ()
243243 if deepspeed_modules is not None :
244- LinearAllreduce , LinearLayer , LmHeadLinearAllreduce = deepspeed_modules
244+ LinearAllreduce , LinearLayer , LmHeadLinearAllreduce = deepspeed_modules [: 3 ]
245245
246246 x = torch .randn (2 , 3 , 64 )
247247 m_linear = DeepSpeedTestM (MyLmHeadModel ).eval ()
You can’t perform that action at this time.
0 commit comments