Commit 0a0166a

add changed
1 parent 800510c commit 0a0166a

File tree

4 files changed (+230 / -99 lines)

src/transformers/models/glm/configuration_glm.py

Lines changed: 2 additions & 0 deletions
@@ -121,6 +121,7 @@ def __init__(
         eos_token_id=[151329, 151336, 151338],
         bos_token_id=None,
         attention_bias=True,
+        sandwich=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -139,6 +140,7 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.sandwich = sandwich
 
         super().__init__(
             pad_token_id=pad_token_id,
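For reference, a minimal sketch of how the new flag is consumed (the default stays False, so existing GLM configs keep today's behaviour). The meta-device instantiation mirrors the conversion script below and avoids allocating real weights; the inspection itself is illustrative and not part of this commit:

    import torch
    from transformers import GlmConfig, GlmForCausalLM

    # Request the two extra RMSNorms (post self-attention and post MLP) added by this commit
    config = GlmConfig(sandwich=True)

    # Instantiate on the meta device, as the conversion script does, just to inspect the layer layout
    with torch.device("meta"):
        model = GlmForCausalLM(config)

    # The first decoder layer should now list post_self_attn_layernorm and post_mlp_layernorm
    print(model.model.layers[0])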

src/transformers/models/glm/convert_glm_weights_to_hf.py

Lines changed: 53 additions & 34 deletions
@@ -9,32 +9,36 @@
 
 from transformers import GlmConfig, GlmForCausalLM, PreTrainedTokenizerFast
 
-
 # fmt: off
 # `None` means we drop the key
-STATE_DICT_MAPPING = {
+BASE_STATE_DICT_MAPPING = {
     # CausalLM keys
-    r"transformer.output_layer.weight": r"lm_head.weight",
+    r"transformer.output_layer.weight": r"lm_head.weight",
 
     # Model keys
-    r"transformer.embedding.word_embeddings.weight": r"model.embed_tokens.weight",
-    r"transformer.rotary_pos_emb.inv_freq": None,
-    r"transformer.encoder.final_layernorm.weight": r"model.norm.weight",
+    r"transformer.embedding.word_embeddings.weight": r"model.embed_tokens.weight",
+    r"transformer.rotary_pos_emb.inv_freq": None,
+    r"transformer.encoder.final_layernorm.weight": r"model.norm.weight",
 
     # Layers keys
-    r"transformer.encoder.layers.(\d+).input_layernorm.weight": r"model.layers.\1.input_layernorm.weight",
-    r"transformer.encoder.layers.(\d+).post_attention_layernorm.weight": r"model.layers.\1.post_attention_layernorm.weight",
+    r"transformer.encoder.layers.(\d+).input_layernorm.weight": r"model.layers.\1.input_layernorm.weight",
+    r"transformer.encoder.layers.(\d+).post_attention_layernorm.weight": r"model.layers.\1.post_attention_layernorm.weight",
 
     # Attention keys
-    r"transformer.encoder.layers.(\d+).self_attention.dense.weight": r"model.layers.\1.self_attn.o_proj.weight",
+    r"transformer.encoder.layers.(\d+).self_attention.dense.weight": r"model.layers.\1.self_attn.o_proj.weight",
     # qkv_proj will later be split in q|k|v|_proj
     r"transformer.encoder.layers.(\d+).self_attention.query_key_value.(weight|bias)": r"model.layers.\1.self_attn.qkv_proj.\2",
 
     # MLP keys
-    r"transformer.encoder.layers.(\d+).mlp.dense_h_to_4h.weight": r"model.layers.\1.mlp.gate_up_proj.weight",
-    r"transformer.encoder.layers.(\d+).mlp.dense_4h_to_h.weight": r"model.layers.\1.mlp.down_proj.weight",
+    r"transformer.encoder.layers.(\d+).mlp.dense_h_to_4h.weight": r"model.layers.\1.mlp.gate_up_proj.weight",
+    r"transformer.encoder.layers.(\d+).mlp.dense_4h_to_h.weight": r"model.layers.\1.mlp.down_proj.weight",
+}
+
+# Additional mappings for sandwich mode
+SANDWICH_STATE_DICT_MAPPING = {
+    r"transformer.encoder.layers.(\d+).post_mlp_layernorm.weight": r"model.layers.\1.post_mlp_layernorm.weight",
+    r"transformer.encoder.layers.(\d+).post_self_attn_layernorm.weight": r"model.layers.\1.post_self_attn_layernorm.weight",
 }
-# fmt: on
 
 
 def load_weights(input_dir: str):
@@ -61,8 +65,8 @@ def load_weights(input_dir: str):
         raise ValueError("No .safetensors or .bin files found in the specified directory.")
 
 
-def map_old_key_to_new(old_key):
-    for pattern, replacement in STATE_DICT_MAPPING.items():
+def map_old_key_to_new(old_key, state_dict_mapping):
+    for pattern, replacement in state_dict_mapping.items():
         if replacement is None:
             if re.fullmatch(pattern, old_key):
                 return None
@@ -75,33 +79,43 @@ def map_old_key_to_new(old_key):
     raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")
 
 
-def convert_state_dict(original_state_dict: dict, config: GlmConfig):
+def convert_state_dict(original_state_dict: dict, config: GlmConfig, use_sandwich: bool = False):
     new_dict = {}
 
     head_dim = config.hidden_size // config.num_attention_heads
     query_size = config.num_attention_heads * head_dim
     kv_size = config.num_key_value_heads * head_dim
 
+    # Combine the base mapping with sandwich mapping if sandwich mode is enabled
+    state_dict_mapping = BASE_STATE_DICT_MAPPING.copy()
+    if use_sandwich:
+        state_dict_mapping.update(SANDWICH_STATE_DICT_MAPPING)
+
     for old_key, value in original_state_dict.items():
-        new_key = map_old_key_to_new(old_key)
-        if new_key is None:
+        try:
+            new_key = map_old_key_to_new(old_key, state_dict_mapping)
+            if new_key is None:
+                continue
+
+            if "qkv_proj." in new_key:
+                q_proj, k_proj, v_proj = (
+                    value[:query_size, ...],
+                    value[query_size : query_size + kv_size, ...],
+                    value[query_size + kv_size :, ...],
+                )
+                new_dict[new_key.replace("qkv_proj.", "q_proj.")] = q_proj
+                new_dict[new_key.replace("qkv_proj.", "k_proj.")] = k_proj
+                new_dict[new_key.replace("qkv_proj.", "v_proj.")] = v_proj
+            else:
+                new_dict[new_key] = value
+        except ValueError:
+            # Skip keys that couldn't be mapped
            continue
 
-        if "qkv_proj." in new_key:
-            q_proj, k_proj, v_proj = (
-                value[:query_size, ...],
-                value[query_size : query_size + kv_size, ...],
-                value[query_size + kv_size :, ...],
-            )
-            new_dict[new_key.replace("qkv_proj.", "q_proj.")] = q_proj
-            new_dict[new_key.replace("qkv_proj.", "k_proj.")] = k_proj
-            new_dict[new_key.replace("qkv_proj.", "v_proj.")] = v_proj
-        else:
-            new_dict[new_key] = value
     return new_dict
 
 
-def convert_config(original_config: dict):
+def convert_config(original_config: dict, use_sandwich: bool = False):
     key_mapping = {
         "vocab_size": "padded_vocab_size",
         "intermediate_size": "ffn_hidden_size",
@@ -128,6 +142,7 @@ def convert_config(original_config: dict):
         else original_config["multi_query_group_num"]
     )
     new_config_kwargs["rope_theta"] = 10000.0 * getattr(original_config, "rope_ratio", 1)
+    new_config_kwargs["sandwich"] = use_sandwich
 
     new_config = GlmConfig(**new_config_kwargs)
     return new_config
@@ -153,16 +168,16 @@ def convert_glm_tokenizer(input_dir, use_post_processor=False):
     return fast_tok
 
 
-def convert_glm_model(input_dir, output_dir, use_post_processor=False):
+def convert_glm_model(input_dir, output_dir, use_post_processor=False, use_sandwich=False):
     # Load and convert config
     with open(os.path.join(input_dir, "config.json")) as f:
         original_config = json.load(f)
-    config = convert_config(original_config)
+    config = convert_config(original_config, use_sandwich)
    config.save_pretrained(output_dir)
 
     # Load and convert weights
     original_state_dict = load_weights(input_dir)
-    new_dict = convert_state_dict(original_state_dict, config)
+    new_dict = convert_state_dict(original_state_dict, config, use_sandwich)
     with torch.device("meta"):
         model = GlmForCausalLM(config)
     model.load_state_dict(new_dict, strict=True, assign=True)
@@ -190,6 +205,10 @@ def convert_glm_model(input_dir, output_dir, use_post_processor=False):
         action="store_true",
         help="Whether to apply post processor with special tokens",
     )
-
+    parser.add_argument(
+        "--sandwich",
+        action="store_true",
+        help="Whether to use two GlmRMSNorm",
+    )
     args = parser.parse_args()
-    convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor)
+    convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor, args.sandwich)
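A hedged usage sketch for the updated converter. The call mirrors the new convert_glm_model signature; the paths are hypothetical, and the --input_dir / --output_dir CLI flag names are assumed from args.input_dir / args.output_dir above:

    # Equivalent CLI (flag names assumed):
    #   python convert_glm_weights_to_hf.py --input_dir <original_ckpt> --output_dir <hf_ckpt> --sandwich
    from convert_glm_weights_to_hf import convert_glm_model  # assumes the script is importable from its directory

    convert_glm_model(
        input_dir="path/to/original_glm_checkpoint",   # hypothetical path
        output_dir="path/to/converted_hf_checkpoint",  # hypothetical path
        use_post_processor=False,
        use_sandwich=True,  # merges SANDWICH_STATE_DICT_MAPPING and sets config.sandwich = True
    )

With use_sandwich=True the extra post_self_attn_layernorm / post_mlp_layernorm weights are mapped; note that convert_state_dict now swallows ValueError, so any key that cannot be matched is silently skipped rather than aborting the conversion.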

src/transformers/models/glm/modeling_glm.py

Lines changed: 67 additions & 59 deletions
@@ -244,6 +244,71 @@ def forward(
         return attn_output, attn_weights
 
 
+class GlmDecoderLayer(nn.Module):
+    def __init__(self, config: GlmConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.sandwich = config.sandwich
+        self.self_attn = GlmAttention(config=config, layer_idx=layer_idx)
+
+        self.mlp = GlmMLP(config)
+        self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.sandwich:
+            self.post_self_attn_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_mlp_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        if self.sandwich:
+            hidden_states = self.post_self_attn_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if self.sandwich:
+            hidden_states = self.post_mlp_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
 class GlmRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -325,60 +390,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-class GlmDecoderLayer(nn.Module):
-    def __init__(self, config: GlmConfig, layer_idx: int):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-
-        self.self_attn = GlmAttention(config=config, layer_idx=layer_idx)
-
-        self.mlp = GlmMLP(config)
-        self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            **kwargs,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
 GLM_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -765,9 +776,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -839,8 +847,8 @@ def forward(
         ```python
         >>> from transformers import AutoTokenizer, GlmForCausalLM
 
-        >>> model = GlmForCausalLM.from_pretrained("meta-glm/Glm-2-7b-hf")
-        >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm/Glm-2-7b-hf")
+        >>> model = GlmForCausalLM.from_pretrained("THUDM/glm-4-9b-chat-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat-hf")
 
         >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")