@@ -9,32 +9,36 @@

 from transformers import GlmConfig, GlmForCausalLM, PreTrainedTokenizerFast

-
 # fmt: off
 # `None` means we drop the key
-STATE_DICT_MAPPING = {
+BASE_STATE_DICT_MAPPING = {
     # CausalLM keys
-    r"transformer.output_layer.weight": r"lm_head.weight",
+    r"transformer.output_layer.weight": r"lm_head.weight",

     # Model keys
-    r"transformer.embedding.word_embeddings.weight": r"model.embed_tokens.weight",
-    r"transformer.rotary_pos_emb.inv_freq": None,
-    r"transformer.encoder.final_layernorm.weight": r"model.norm.weight",
+    r"transformer.embedding.word_embeddings.weight": r"model.embed_tokens.weight",
+    r"transformer.rotary_pos_emb.inv_freq": None,
+    r"transformer.encoder.final_layernorm.weight": r"model.norm.weight",

     # Layers keys
-    r"transformer.encoder.layers.(\d+).input_layernorm.weight": r"model.layers.\1.input_layernorm.weight",
-    r"transformer.encoder.layers.(\d+).post_attention_layernorm.weight": r"model.layers.\1.post_attention_layernorm.weight",
+    r"transformer.encoder.layers.(\d+).input_layernorm.weight": r"model.layers.\1.input_layernorm.weight",
+    r"transformer.encoder.layers.(\d+).post_attention_layernorm.weight": r"model.layers.\1.post_attention_layernorm.weight",

     # Attention keys
-    r"transformer.encoder.layers.(\d+).self_attention.dense.weight": r"model.layers.\1.self_attn.o_proj.weight",
+    r"transformer.encoder.layers.(\d+).self_attention.dense.weight": r"model.layers.\1.self_attn.o_proj.weight",
     # qkv_proj will later be split in q|k|v|_proj
     r"transformer.encoder.layers.(\d+).self_attention.query_key_value.(weight|bias)": r"model.layers.\1.self_attn.qkv_proj.\2",

     # MLP keys
-    r"transformer.encoder.layers.(\d+).mlp.dense_h_to_4h.weight": r"model.layers.\1.mlp.gate_up_proj.weight",
-    r"transformer.encoder.layers.(\d+).mlp.dense_4h_to_h.weight": r"model.layers.\1.mlp.down_proj.weight",
+    r"transformer.encoder.layers.(\d+).mlp.dense_h_to_4h.weight": r"model.layers.\1.mlp.gate_up_proj.weight",
+    r"transformer.encoder.layers.(\d+).mlp.dense_4h_to_h.weight": r"model.layers.\1.mlp.down_proj.weight",
+}
+
+# Additional mappings for sandwich mode
+SANDWICH_STATE_DICT_MAPPING = {
+    r"transformer.encoder.layers.(\d+).post_mlp_layernorm.weight": r"model.layers.\1.post_mlp_layernorm.weight",
+    r"transformer.encoder.layers.(\d+).post_self_attn_layernorm.weight": r"model.layers.\1.post_self_attn_layernorm.weight",
 }
-# fmt: on


 def load_weights(input_dir: str):
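# Illustrative sketch (not part of the diff) of how one entry in the mapping tables above
# resolves: each key is a regex whose capture group carries the layer index into the new
# Hugging Face key. The real lookup is map_old_key_to_new() below; the re.sub call here is
# an assumption about its replacement branch, since only the `None`/drop branch is shown.
import re

pattern = r"transformer.encoder.layers.(\d+).self_attention.dense.weight"
replacement = r"model.layers.\1.self_attn.o_proj.weight"
old_key = "transformer.encoder.layers.3.self_attention.dense.weight"

if re.fullmatch(pattern, old_key):
    new_key = re.sub(pattern, replacement, old_key)  # \1 re-inserts the captured index "3"
    print(new_key)  # model.layers.3.self_attn.o_proj.weight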
@@ -61,8 +65,8 @@ def load_weights(input_dir: str):
         raise ValueError("No .safetensors or .bin files found in the specified directory.")


-def map_old_key_to_new(old_key):
-    for pattern, replacement in STATE_DICT_MAPPING.items():
+def map_old_key_to_new(old_key, state_dict_mapping):
+    for pattern, replacement in state_dict_mapping.items():
         if replacement is None:
             if re.fullmatch(pattern, old_key):
                 return None
@@ -75,33 +79,43 @@ def map_old_key_to_new(old_key):
     raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")


-def convert_state_dict(original_state_dict: dict, config: GlmConfig):
+def convert_state_dict(original_state_dict: dict, config: GlmConfig, use_sandwich: bool = False):
     new_dict = {}

     head_dim = config.hidden_size // config.num_attention_heads
     query_size = config.num_attention_heads * head_dim
     kv_size = config.num_key_value_heads * head_dim

+    # Combine the base mapping with sandwich mapping if sandwich mode is enabled
+    state_dict_mapping = BASE_STATE_DICT_MAPPING.copy()
+    if use_sandwich:
+        state_dict_mapping.update(SANDWICH_STATE_DICT_MAPPING)
+
     for old_key, value in original_state_dict.items():
-        new_key = map_old_key_to_new(old_key)
-        if new_key is None:
+        try:
+            new_key = map_old_key_to_new(old_key, state_dict_mapping)
+            if new_key is None:
+                continue
+
+            if "qkv_proj." in new_key:
+                q_proj, k_proj, v_proj = (
+                    value[:query_size, ...],
+                    value[query_size : query_size + kv_size, ...],
+                    value[query_size + kv_size :, ...],
+                )
+                new_dict[new_key.replace("qkv_proj.", "q_proj.")] = q_proj
+                new_dict[new_key.replace("qkv_proj.", "k_proj.")] = k_proj
+                new_dict[new_key.replace("qkv_proj.", "v_proj.")] = v_proj
+            else:
+                new_dict[new_key] = value
+        except ValueError:
+            # Skip keys that couldn't be mapped
             continue

-        if "qkv_proj." in new_key:
-            q_proj, k_proj, v_proj = (
-                value[:query_size, ...],
-                value[query_size : query_size + kv_size, ...],
-                value[query_size + kv_size :, ...],
-            )
-            new_dict[new_key.replace("qkv_proj.", "q_proj.")] = q_proj
-            new_dict[new_key.replace("qkv_proj.", "k_proj.")] = k_proj
-            new_dict[new_key.replace("qkv_proj.", "v_proj.")] = v_proj
-        else:
-            new_dict[new_key] = value
     return new_dict


-def convert_config(original_config: dict):
+def convert_config(original_config: dict, use_sandwich: bool = False):
     key_mapping = {
         "vocab_size": "padded_vocab_size",
         "intermediate_size": "ffn_hidden_size",
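# Illustrative sketch of the qkv_proj split performed in convert_state_dict() above, with
# made-up dimensions (not taken from any real GLM config): 8 query heads and 2 key/value
# heads of head_dim 8 give a fused [query_size + 2 * kv_size, hidden_size] weight that is
# sliced row-wise into q, k and v.
import torch

hidden_size, num_attention_heads, num_key_value_heads = 64, 8, 2
head_dim = hidden_size // num_attention_heads   # 8
query_size = num_attention_heads * head_dim     # 64
kv_size = num_key_value_heads * head_dim        # 16

qkv = torch.randn(query_size + 2 * kv_size, hidden_size)  # fused qkv_proj.weight, shape [96, 64]
q = qkv[:query_size, ...]
k = qkv[query_size : query_size + kv_size, ...]
v = qkv[query_size + kv_size :, ...]
print(q.shape, k.shape, v.shape)  # torch.Size([64, 64]) torch.Size([16, 64]) torch.Size([16, 64])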
@@ -128,6 +142,7 @@ def convert_config(original_config: dict):
         else original_config["multi_query_group_num"]
     )
     new_config_kwargs["rope_theta"] = 10000.0 * getattr(original_config, "rope_ratio", 1)
+    new_config_kwargs["sandwich"] = use_sandwich
     new_config = GlmConfig(**new_config_kwargs)
     return new_config

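# Note on this hunk: convert_config() now forwards the flag into GlmConfig as a "sandwich"
# field, so it should end up in the converted checkpoint's config.json. A hedged sanity
# check on a converted directory (path is a placeholder):
from transformers import GlmConfig

cfg = GlmConfig.from_pretrained("/path/to/output_dir")
print(getattr(cfg, "sandwich", False))  # expected True when converted with --sandwich

# Aside: `getattr(original_config, "rope_ratio", 1)` above is called on a plain dict
# (the json.load result), so it always returns the default 1; `original_config.get("rope_ratio", 1)`
# would be needed for a checkpoint's rope_ratio to take effect.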
@@ -153,16 +168,16 @@ def convert_glm_tokenizer(input_dir, use_post_processor=False):
     return fast_tok


-def convert_glm_model(input_dir, output_dir, use_post_processor=False):
+def convert_glm_model(input_dir, output_dir, use_post_processor=False, use_sandwich=False):
     # Load and convert config
     with open(os.path.join(input_dir, "config.json")) as f:
         original_config = json.load(f)
-    config = convert_config(original_config)
+    config = convert_config(original_config, use_sandwich)
     config.save_pretrained(output_dir)

     # Load and convert weights
     original_state_dict = load_weights(input_dir)
-    new_dict = convert_state_dict(original_state_dict, config)
+    new_dict = convert_state_dict(original_state_dict, config, use_sandwich)
     with torch.device("meta"):
         model = GlmForCausalLM(config)
     model.load_state_dict(new_dict, strict=True, assign=True)
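# Why the meta device is used above: building GlmForCausalLM under torch.device("meta")
# allocates no real parameter storage, and load_state_dict(..., assign=True) (available in
# recent PyTorch) attaches the converted tensors directly instead of copying them into
# pre-initialized weights. A minimal standalone sketch of the same pattern:
import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(4, 4, bias=False)   # parameter lives on the meta device, no data yet

state = {"weight": torch.zeros(4, 4)}     # stands in for the converted state dict
layer.load_state_dict(state, strict=True, assign=True)
print(layer.weight.device)                # cpu: the real tensor was assigned in place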
@@ -190,6 +205,10 @@ def convert_glm_model(input_dir, output_dir, use_post_processor=False):
         action="store_true",
         help="Whether to apply post processor with special tokens",
     )
-
+    parser.add_argument(
+        "--sandwich",
+        action="store_true",
+        help="Whether to use the two extra GlmRMSNorm layers (sandwich norm)",
+    )
     args = parser.parse_args()
-    convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor)
+    convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor, args.sandwich)
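# Example use of the new flag. The script name and the --input_dir / --output_dir spellings
# below are assumptions (those arguments are defined outside the hunks shown here); only
# --sandwich itself is added by this diff:
#
#   python convert_glm_weights_to_hf.py --input_dir /path/to/glm_checkpoint \
#       --output_dir /path/to/hf_checkpoint --sandwich
#
# Assuming the rest of convert_glm_model() saves the model and tokenizer to output_dir, as
# the config.save_pretrained call above suggests, the result can be reloaded as a regular
# Hugging Face checkpoint for a quick sanity check:
from transformers import GlmForCausalLM

model = GlmForCausalLM.from_pretrained("/path/to/hf_checkpoint")
print(model.config.num_hidden_layers, getattr(model.config, "sandwich", False))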