[assets] update wechat (#8962)

This commit is contained in:
Yaowei Zheng
2025-08-19 02:55:09 +08:00
committed by GitHub
parent 003a2acb1a
commit 2c31279316
8 changed files with 29 additions and 22 deletions

View File

@@ -776,6 +776,10 @@ register_model_group(
register_model_group(
models={
"Gemma-3-270M": {
DownloadSource.DEFAULT: "google/gemma-3-270m",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m",
},
"Gemma-3-4B": {
DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
@@ -788,6 +792,10 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
},
"Gemma-3-270M-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-270m-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m-it",
},
"Gemma-3-4B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-4b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
@@ -1669,8 +1677,8 @@ register_model_group(
},
"MiMo-VL-7B-RL-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508"
}
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
},
},
template="mimo_vl",
multimodal=True,
@@ -1685,7 +1693,7 @@ register_model_group(
},
"MiMo-VL-7B-SFT-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508"
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
},
},
template="qwen2_vl",

View File

@@ -32,6 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_transformers_version_greater_than
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -304,6 +305,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)

View File

@@ -16,11 +16,10 @@ import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.misc import check_version
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model

View File

@@ -111,6 +111,7 @@ class CustomDPOTrainer(DPOTrainer):
if self.bco_gemma >= 1e-6:
from trl.trainer import RunningMoments
self.running = RunningMoments(self.accelerator)
@override
@@ -161,14 +162,14 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logps: "torch.Tensor",
rejected_logps: "torch.Tensor",
reference_chosen_logps: "torch.Tensor",
reference_rejected_logps: "torch.Tensor"
reference_rejected_logps: "torch.Tensor",
) -> "torch.Tensor":
chosen_logratios = chosen_logps - reference_chosen_logps
rejected_logratios = rejected_logps - reference_rejected_logps
chosen_rewards = self.beta * chosen_logratios
rejected_rewards = self.beta * rejected_logratios
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
self.running.update(rewards) # update baseline
self.running.update(rewards) # update baseline
delta = self.running.mean
bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
-(self.beta * rejected_logratios - delta)
@@ -195,15 +196,12 @@ class CustomDPOTrainer(DPOTrainer):
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
else:
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
if self.bco_gemma > 1e-6:
bco_losses = self.bco_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
losses += bco_losses * self.bco_gemma
@@ -288,7 +286,7 @@ class CustomDPOTrainer(DPOTrainer):
losses += self.ftx_gamma * sft_loss
if self.bco_gemma > 1e-6:
# re-weigthing for MPO
losses /= (self.ftx_gamma + self.bco_gemma + 1.0)
losses /= self.ftx_gamma + self.bco_gemma + 1.0
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()