[misc] lint code (#9395)
@@ -16,4 +16,3 @@ from .workflow import run_dpo, run_pt, run_sft
 __all__ = ["run_dpo", "run_pt", "run_sft"]
@@ -75,12 +75,17 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper


 def _check_model_support(model_args: ModelArguments):
     from transformers import AutoConfig as HfAutoConfig

-    config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
+    config = HfAutoConfig.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
     if config.model_type not in MCA_SUPPORTED_MODELS:
         raise ValueError(f"Model {config.model_type} is not supported by MCA.")


 def run_pt(
     model_args: ModelArguments,
     data_args: DataArguments,
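For orientation, `_check_model_support` only needs the model's config to decide whether the architecture is usable with MCA (mcore-adapter), so no weights are loaded. Below is a minimal standalone sketch of the same guard pattern; the `SUPPORTED` set and the checkpoint name in the usage line are illustrative placeholders, not values taken from the repository:

```python
from transformers import AutoConfig

# Hypothetical whitelist standing in for MCA_SUPPORTED_MODELS.
SUPPORTED = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}


def check_model_support(model_name_or_path: str, trust_remote_code: bool = False) -> None:
    # Only the config is fetched and parsed, so the check is cheap and fails early.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
    if config.model_type not in SUPPORTED:
        raise ValueError(f"Model {config.model_type} is not supported by MCA.")


check_model_support("Qwen/Qwen2-0.5B-Instruct")  # model_type "qwen2" -> passes
```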
@@ -161,22 +166,23 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["multi_modal_projector"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
-                p.requires_grad_(False)
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+        params_to_freeze = []
+        if finetuning_args.freeze_vision_tower:
+            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
+        if finetuning_args.freeze_multi_modal_projector:
+            params_to_freeze.extend(["multi_modal_projector"])
+        if finetuning_args.freeze_language_model:
+            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+        if params_to_freeze:
+            for name, p in model.named_parameters():
+                if any(name.startswith(k) for k in params_to_freeze):
+                    p.requires_grad_(False)

-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
         padding="max_length" if pad_to_max else "longest",
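To illustrate the refactor above: instead of walking `named_parameters()` once per flag, the new code collects all prefixes to freeze and walks the parameters a single time, with identical effect. The toy module below mimics that collect-then-freeze pattern; the class and layer names are invented for the example and are not LlamaFactory APIs:

```python
import torch.nn as nn


class ToyVLM(nn.Module):
    """Toy stand-in whose submodule names mimic the prefixes used above."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_model = nn.ModuleDict({"patch_embed": nn.Linear(16, 32), "blocks": nn.Linear(32, 32)})
        self.multi_modal_projector = nn.Linear(32, 64)
        self.decoder = nn.Linear(64, 64)


model = ToyVLM()

# Collect every prefix to freeze first, then do one pass over named_parameters().
params_to_freeze = ["vision_model.blocks", "vision_model.patch_embed", "multi_modal_projector"]
for name, p in model.named_parameters():
    if any(name.startswith(k) for k in params_to_freeze):
        p.requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only decoder.weight and decoder.bias remain trainable
```

The behavior matches the previous three separate loops; the refactor just avoids repeating the long `getattr(...)` condition and the parameter walk.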
@@ -239,9 +245,7 @@ def run_dpo(
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     data_args.cutoff_len -= 1

-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     dpo_config = DPOConfig(
         beta=finetuning_args.pref_beta,
         pref_loss=finetuning_args.pref_loss,
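Both `run_sft` and `run_dpo` compute `pad_to_max` from `expert_model_parallel_size` and use it to switch the collator between dynamic ("longest") and fixed-length ("max_length") padding, presumably to keep batch shapes uniform when expert parallelism is active. The sketch below shows the same switch using the plain `transformers` tokenizer padding options rather than the project's `SFTDataCollatorWith4DAttentionMask`; the checkpoint name and parallel size are placeholders:

```python
from transformers import AutoTokenizer

# Hypothetical stand-in for training_args.expert_model_parallel_size.
expert_model_parallel_size = 2

pad_to_max = expert_model_parallel_size is not None and expert_model_parallel_size > 1
padding = "max_length" if pad_to_max else "longest"

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
batch = tokenizer(
    ["short prompt", "a somewhat longer example prompt"],
    padding=padding,
    max_length=32,
    truncation=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # (2, 32) under "max_length"; (2, longest-in-batch) under "longest"
```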
@@ -289,4 +293,3 @@ def run_dpo(
             keys += ["eval_loss"]

         plot_loss(training_args.output_dir, keys=keys)
@@ -71,13 +71,17 @@ def _training_function(config: dict[str, Any]) -> None:
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")

         if finetuning_args.stage == "pt":
             from .mca import run_pt as run_pt_mca

             run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
         elif finetuning_args.stage == "sft":
             from .mca import run_sft as run_sft_mca

             run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-        else:  # dpo
+        elif finetuning_args.stage == "dpo":
             from .mca import run_dpo as run_dpo_mca

             run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
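Replacing the bare `else:  # dpo` with an explicit `elif finetuning_args.stage == "dpo":` makes the handled stages explicit and leaves room for a trailing `else` that rejects anything unexpected. A reduced sketch of that dispatch shape; the stub return values and the final `raise` are illustrative assumptions, not code visible in this diff:

```python
def dispatch(stage: str) -> str:
    # One explicit branch per supported stage.
    if stage == "pt":
        return "run_pt"
    elif stage == "sft":
        return "run_sft"
    elif stage == "dpo":
        return "run_dpo"
    else:
        # With a bare `else:  # dpo`, an unknown stage would silently run DPO.
        raise ValueError(f"Unknown stage: {stage}.")


print(dispatch("dpo"))  # run_dpo
```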