@@ -241,6 +241,42 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li
241241 not_missing_keys .append (missing )
242242 return [k for k in missing_keys if k not in not_missing_keys ]
243243
244+ def update_tp_plan (self , config ):
245+ text_plan = {
246+ "layers.*.self_attn.q_proj.weight" : "local_colwise" ,
247+ "layers.*.self_attn.q_proj.weight_scale" : "local_colwise" ,
248+ "layers.*.self_attn.k_proj.weight" : "local_colwise" ,
249+ "layers.*.self_attn.k_proj.weight_scale" : "local_colwise" ,
250+ "layers.*.self_attn.v_proj.weight" : "local_colwise" ,
251+ "layers.*.self_attn.v_proj.weight_scale" : "local_colwise" ,
252+ "layers.*.self_attn.o_proj.weight" : "local_rowwise" ,
253+ "layers.*.self_attn" : "gather" ,
254+ "layers.*.input_layernorm.weight" : "sequence_parallel" ,
255+ "layers.*.post_attention_layernorm.weight" : "sequence_parallel" ,
256+ "norm.weight" : "sequence_parallel" ,
257+ "layers.*.feed_forward.shared_expert.gate_proj.weight" : "local_colwise" ,
258+ "layers.*.feed_forward.shared_expert.gate_proj.weight_scale" : "local_colwise" ,
259+ "layers.*.feed_forward.shared_expert.up_proj.weight" : "local_colwise" ,
260+ "layers.*.feed_forward.shared_expert.up_proj.weight_scale" : "local_colwise" ,
261+ "layers.*.feed_forward.shared_expert.down_proj.weight" : "local_rowwise" ,
262+ "layers.*.feed_forward.experts" : "local" ,
263+ "layers.*.feed_forward" : "gather" ,
264+ "layers.*.feed_forward.experts.*.gate_proj.weight" : "local_colwise" ,
265+ "layers.*.feed_forward.experts.*.gate_proj.weight_scale" : "local_colwise" ,
266+ "layers.*.feed_forward.experts.*.up_proj.weight" : "local_colwise" ,
267+ "layers.*.feed_forward.experts.*.up_proj.weight_scale" : "local_colwise" ,
268+ "layers.*.feed_forward.experts.*.down_proj.weight" : "local_rowwise" ,
269+ # For Fused implementation
270+ "layers.*.feed_forward.experts.gate_up_proj" : "local_packed_rowwise" ,
271+ "layers.*.feed_forward.experts.gate_up_proj_scale" : "local_packed_rowwise" ,
272+ "layers.*.feed_forward.experts.down_proj" : "local_colwise" ,
273+ }
274+ if config .get_text_config () is not None :
275+ config .get_text_config ().base_model_tp_plan = text_plan
276+ else :
277+ config .base_model_tp_plan = text_plan
278+ return config
279+
244280 def is_serializable (self , safe_serialization = None ):
245281 return True
246282
0 commit comments