Former-commit-id: 79731ae13ecd17eb8646fb53162c81dddfef3b00
This commit is contained in:
hoshi-hiyouga
2025-01-14 18:40:07 +08:00
committed by GitHub
parent 1bb06e06df
commit 41a9e231cb
5 changed files with 11 additions and 7 deletions

View File

@@ -110,7 +110,7 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if getattr(config, "model_type", None) == "minicpmo":
setattr(config, "init_audio", False)
setattr(config, "init_tts", False)
@@ -119,7 +119,7 @@ def patch_config(
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers >= 4.47.1, please upgrade it.")
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@@ -153,7 +153,7 @@ def patch_model(
):
gen_config.do_sample = True
if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
model.generate.__func__
):
model.generate = MethodType(PreTrainedModel.generate, model)