[config] update args (#7231)

Former-commit-id: f71a901840811bf560df671ec63a146ff99140c6

hoshi-hiyouga authored on 2025-03-10 23:04:43 +08:00 (committed by GitHub)
parent cf58a6d860
commit 71a1c1321a
16 changed files with 89 additions and 74 deletions

View File

@@ -415,15 +415,15 @@ class FinetuningArguments(
     )
     freeze_vision_tower: bool = field(
         default=True,
-        metadata={"help": "Whether or not to freeze vision tower in MLLM training."},
+        metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
     )
     freeze_multi_modal_projector: bool = field(
         default=True,
         metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
     )
-    train_mm_proj_only: bool = field(
+    freeze_language_model: bool = field(
         default=False,
-        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+        metadata={"help": "Whether or not to freeze the language model in MLLM training."},
     )
     compute_accuracy: bool = field(
         default=False,
@@ -455,8 +455,6 @@ class FinetuningArguments(
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.apollo_target: List[str] = split_arg(self.apollo_target)
-        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -484,9 +482,6 @@ class FinetuningArguments(
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
 
-        if self.train_mm_proj_only and self.finetuning_type != "full":
-            raise ValueError("`train_mm_proj_only` is only valid for full training.")
-
         if self.finetuning_type != "lora":
             if self.loraplus_lr_ratio is not None:
                 raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
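
With `train_mm_proj_only` gone, projector-only training is expressed through the three freeze flags directly. A minimal sketch of the equivalent settings, inferred from the deleted `__post_init__` coercions and the removed "only valid for full training" check rather than taken from the commit (the dict name is illustrative):

# Train only the multimodal projector: keep both the vision tower and the
# language model frozen while the projector stays trainable.
mm_proj_only_args = {  # hypothetical name; keys map onto FinetuningArguments fields
    "finetuning_type": "full",              # the old flag required full training
    "freeze_vision_tower": True,            # default; vision tower stays frozen
    "freeze_multi_modal_projector": False,  # projector receives gradient updates
    "freeze_language_model": True,          # language model stays frozen
}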

View File

@@ -23,6 +23,8 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
+from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+
 
 @dataclass
 class BaseModelArguments:
@@ -77,12 +79,12 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
+    rope_scaling: Optional[RopeScaling] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
-        default="auto",
+    flash_attn: AttentionFunction = field(
+        default=AttentionFunction.AUTO,
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
@@ -129,8 +131,8 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
-    infer_backend: Literal["huggingface", "vllm"] = field(
-        default="huggingface",
+    infer_backend: EngineName = field(
+        default=EngineName.HF,
         metadata={"help": "Backend engine used at inference."},
     )
     offload_folder: str = field(
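
For context, the `AttentionFunction`, `EngineName`, and `RopeScaling` names imported from `..extras.constants` are presumably string-valued enums. A minimal sketch of definitions consistent with the replaced `Literal` choices and the defaults above, not copied from the repository:

from enum import Enum, unique


@unique
class AttentionFunction(str, Enum):  # values mirror the old Literal["auto", "disabled", "sdpa", "fa2"]
    AUTO = "auto"
    DISABLED = "disabled"
    SDPA = "sdpa"
    FA2 = "fa2"


@unique
class EngineName(str, Enum):  # EngineName.HF replaces the old default "huggingface"
    HF = "huggingface"
    VLLM = "vllm"


@unique
class RopeScaling(str, Enum):  # values mirror the old Literal["linear", "dynamic", "yarn", "llama3"]
    LINEAR = "linear"
    DYNAMIC = "dynamic"
    YARN = "yarn"
    LLAMA3 = "llama3"

Assuming a str mixin like this, existing string-based configs keep working, since e.g. AttentionFunction.FA2 == "fa2" evaluates to True.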