[assets] update wechat (#8962)
@@ -776,6 +776,10 @@ register_model_group(
 register_model_group(
     models={
+        "Gemma-3-270M": {
+            DownloadSource.DEFAULT: "google/gemma-3-270m",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m",
+        },
         "Gemma-3-4B": {
             DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
@@ -788,6 +792,10 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
         },
+        "Gemma-3-270M-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-270m-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m-it",
+        },
         "Gemma-3-4B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-3-4b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
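For orientation, the sketch below shows how a registry entry like the new Gemma-3-270M group can be resolved to a concrete repository ID: each model name maps to per-source repo IDs, and a downloader picks the one for the configured hub. The DownloadSource values and the resolve_repo_id helper are illustrative assumptions, not code from this repository.

from enum import Enum, unique


@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"  # Hugging Face Hub
    MODELSCOPE = "ms"  # ModelScope


# hypothetical registry mirroring the structure added in this diff
SUPPORTED_MODELS = {
    "Gemma-3-270M": {
        DownloadSource.DEFAULT: "google/gemma-3-270m",
        DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m",
    },
}


def resolve_repo_id(name: str, use_modelscope: bool = False) -> str:
    """Pick the repository ID for the requested hub, falling back to the default source."""
    sources = SUPPORTED_MODELS[name]
    if use_modelscope and DownloadSource.MODELSCOPE in sources:
        return sources[DownloadSource.MODELSCOPE]
    return sources[DownloadSource.DEFAULT]


print(resolve_repo_id("Gemma-3-270M", use_modelscope=True))  # LLM-Research/gemma-3-270m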
@@ -1669,8 +1677,8 @@ register_model_group(
         },
         "MiMo-VL-7B-RL-2508": {
             DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
-            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508"
-        }
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
+        },
     },
     template="mimo_vl",
     multimodal=True,
@@ -1685,7 +1693,7 @@ register_model_group(
         },
         "MiMo-VL-7B-SFT-2508": {
             DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
-            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508"
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
         },
     },
     template="qwen2_vl",
@@ -32,6 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
+from ..extras.packages import is_transformers_version_greater_than
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -304,6 +305,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
+        raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
+
     _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
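A rough sketch of the kind of version gate behind is_transformers_version_greater_than, assuming the packaging library and importlib.metadata are available; the project's real helper may be implemented differently.

import importlib.metadata
from functools import lru_cache

from packaging import version


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    # compare the installed transformers version against the given threshold
    # (inclusive, matching the ">=4.53.0" wording of the error message above)
    return version.parse(importlib.metadata.version("transformers")) >= version.parse(content)


if is_transformers_version_greater_than("4.53.0"):
    print("neat packing would be rejected in this environment")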
@@ -16,11 +16,10 @@ import re
 from typing import TYPE_CHECKING

 import torch
-from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras import logging
 from ..extras.misc import check_version
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
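As a reminder of what the peft imports above are used for, here is a minimal, hedged LoRA setup; the checkpoint name and hyperparameters are placeholders rather than values taken from this diff.

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# placeholder checkpoint; any causal LM would do
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,  # LoRA rank
    lora_alpha=16,  # scaling factor
    lora_dropout=0.05,
    target_modules=["c_attn"],  # fused attention projection in GPT-2-style models
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # only the injected LoRA matrices are trainable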
@@ -111,6 +111,7 @@ class CustomDPOTrainer(DPOTrainer):
         if self.bco_gemma >= 1e-6:
             from trl.trainer import RunningMoments

             self.running = RunningMoments(self.accelerator)

     @override
@@ -161,14 +162,14 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps: "torch.Tensor",
         rejected_logps: "torch.Tensor",
         reference_chosen_logps: "torch.Tensor",
-        reference_rejected_logps: "torch.Tensor"
+        reference_rejected_logps: "torch.Tensor",
     ) -> "torch.Tensor":
         chosen_logratios = chosen_logps - reference_chosen_logps
         rejected_logratios = rejected_logps - reference_rejected_logps
         chosen_rewards = self.beta * chosen_logratios
         rejected_rewards = self.beta * rejected_logratios
         rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
-        self.running.update(rewards) # update baseline
+        self.running.update(rewards)  # update baseline
         delta = self.running.mean
         bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
             -(self.beta * rejected_logratios - delta)
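To make the baseline-corrected term above easier to follow, here is a self-contained sketch with a tiny running-mean class standing in for trl's RunningMoments; the shapes, values, and helper class are illustrative assumptions, not the trainer's actual code.

import torch
import torch.nn.functional as F


class RunningMean:
    """Tiny stand-in for trl's RunningMoments: tracks the mean reward seen so far."""

    def __init__(self) -> None:
        self.mean = 0.0
        self.count = 0

    def update(self, value: torch.Tensor) -> None:
        self.count += 1
        self.mean += (value.item() - self.mean) / self.count


def bco_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta, running):
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    rewards = torch.cat((beta * chosen_logratios, beta * rejected_logratios), 0).mean().detach()
    running.update(rewards)  # update the reward baseline
    delta = running.mean  # baseline subtracted from both sides
    return -F.logsigmoid(beta * chosen_logratios - delta) - F.logsigmoid(-(beta * rejected_logratios - delta))


running = RunningMean()
loss = bco_loss(
    torch.tensor([-1.0]), torch.tensor([-3.0]), torch.tensor([-1.5]), torch.tensor([-2.5]), 0.1, running
)
print(loss)  # one loss value per pair in the batch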
@@ -195,15 +196,12 @@ class CustomDPOTrainer(DPOTrainer):
             rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
         else:
             losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                 policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
             )

         if self.bco_gemma > 1e-6:
             bco_losses = self.bco_loss(
-                policy_chosen_logps,
-                policy_rejected_logps,
-                reference_chosen_logps,
-                reference_rejected_logps
+                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
             )
             losses += bco_losses * self.bco_gemma
@@ -288,7 +286,7 @@ class CustomDPOTrainer(DPOTrainer):
             losses += self.ftx_gamma * sft_loss
         if self.bco_gemma > 1e-6:
             # re-weighting for MPO
-            losses /= (self.ftx_gamma + self.bco_gemma + 1.0)
+            losses /= self.ftx_gamma + self.bco_gemma + 1.0

         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
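The hunk above averages the mixed objective by its total weight; below is a small numeric sketch of that re-weighting, with placeholder weights and per-example losses rather than values from a real run.

import torch

ftx_gamma, bco_gemma = 0.1, 0.5  # SFT and BCO mixing weights (placeholders)
dpo_losses = torch.tensor([0.65, 0.70])  # per-example DPO loss
sft_loss = torch.tensor([1.20, 1.10])  # per-example SFT cross-entropy
bco_losses = torch.tensor([0.40, 0.55])  # per-example BCO loss

losses = dpo_losses + ftx_gamma * sft_loss + bco_gemma * bco_losses
losses /= ftx_gamma + bco_gemma + 1.0  # re-weighting for MPO: divide by the total weight
print(losses)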