Commit b1ebe78

bvantuanzaristei authored and committed
Fix key mapping for VLMs (huggingface#39029)
* fix key mapping for VLMs
* use __mro__ instead
* update key mapping in save_pretrained
1 parent 0a9305e commit b1ebe78

File tree

1 file changed: +8 -2 lines changed

src/transformers/modeling_utils.py

Lines changed: 8 additions & 2 deletions
@@ -3746,7 +3746,11 @@ def save_pretrained(
                 module_map[name + f".{key}"] = module
         state_dict = model_to_save.state_dict()

-        if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
+        if any(
+            allowed_name in class_name.__name__.lower()
+            for class_name in self.__class__.__mro__[:-1]
+            for allowed_name in VLMS
+        ):
             reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}

             original_state_dict = {}
@@ -4402,7 +4406,9 @@ def from_pretrained(

         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
-        if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
+        if key_mapping is None and any(
+            allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
+        ):
             key_mapping = cls._checkpoint_conversion_mapping

         # Not used anymore -- remove them from the kwargs
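
For context on why the check now walks the MRO: the old condition matched the VLMS names against the concrete class name only, so a user-defined subclass of a VLM (whose own name need not contain the VLM keyword) would silently skip the checkpoint key mapping. Below is a minimal, runnable sketch of the before/after behavior; the VLMS list and the class names here are illustrative stand-ins, not the real transformers definitions.

    # Stand-in for the hardcoded VLMS list in modeling_utils.py (illustrative).
    VLMS = ["llava"]

    class LlavaForConditionalGeneration:  # stand-in VLM base class
        pass

    class MyCustomModel(LlavaForConditionalGeneration):  # subclass; name lacks "llava"
        pass

    cls = MyCustomModel

    # Old check: inspects only the concrete class name, so the subclass is missed.
    old = any(allowed_name in cls.__name__.lower() for allowed_name in VLMS)

    # New check: walks the MRO minus `object` (__mro__[:-1]), so any VLM ancestor counts.
    new = any(
        allowed_name in class_name.__name__.lower()
        for class_name in cls.__mro__[:-1]
        for allowed_name in VLMS
    )

    print(old, new)  # -> False True

The [:-1] slice drops object, the final entry of every MRO, whose name could never match a VLM keyword anyway.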
