[data] Fix qwen_2vl with valuehead (#9078)
@@ -211,9 +211,23 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
        if isinstance(self.pretrained_model, PeftModel):
            self.pretrained_model.create_or_update_model_card(output_dir)

    def get_rope_index_func(self: "AutoModelForCausalLMWithValueHead"):
        if isinstance(self.pretrained_model, PeftModel):
            base_model = self.pretrained_model.base_model.model
        else:
            base_model = self.pretrained_model

        if base_model and hasattr(base_model, "get_rope_index"):
            return base_model.get_rope_index
        elif base_model and hasattr(base_model, "model") and hasattr(base_model.model, "get_rope_index"):
            return base_model.model.get_rope_index
        else:
            return None

    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
    setattr(model, "get_rope_index", get_rope_index_func(model))
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
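The hunk above exposes get_rope_index on the value-head wrapper: AutoModelForCausalLMWithValueHead (and a PEFT wrapper around it) hides the underlying Qwen2-VL model, so the wrapper looks through pretrained_model (unwrapping PeftModel first) and forwards the base model's get_rope_index, or None if the model has no such method. Below is a minimal usage sketch, not part of the commit; the batch keys (input_ids, attention_mask, image_grid_thw) and the (position_ids, rope_deltas) return pair follow Qwen2-VL-style models in transformers and are assumptions that may differ across versions.

    # Sketch only: reach get_rope_index through the patched value-head wrapper.
    # `model` is the patched AutoModelForCausalLMWithValueHead; `batch` is a
    # hypothetical collated multimodal batch.
    rope_index_fn = getattr(model, "get_rope_index", None)
    if rope_index_fn is not None:
        position_ids, rope_deltas = rope_index_fn(
            input_ids=batch["input_ids"],                # (batch, seq_len) token ids
            image_grid_thw=batch.get("image_grid_thw"),  # per-image (t, h, w) grids, if any
            attention_mask=batch["attention_mask"],
        )

Because the patch stores a bound method from the base model (rather than None for text-only models), callers can keep a single hasattr/getattr check instead of special-casing the value-head and PEFT wrappers.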