support reporting custom args to external loggers

commit a897d46049 (parent adff887659)
Author: hiyouga
Date: 2024-12-19 14:57:09 +00:00
Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea

20 changed files with 164 additions and 124 deletions

View File

@@ -171,7 +171,10 @@ class HuggingfaceEngine(BaseEngine):
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
-            gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device)
+            if torch.is_floating_point(value):
+                value = value.to(model.dtype)
+
+            gen_kwargs[key] = value.to(model.device)
 
         return gen_kwargs, prompt_length
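
Why the cast is now gated on torch.is_floating_point: generation kwargs can carry integer tensors (token ids, masks), and blindly casting those to the model dtype would corrupt them. A standalone sketch of the new behavior (the names and values here are illustrative, not the repository's):

import torch

model_dtype, device = torch.bfloat16, "cpu"  # assumed for illustration
gen_kwargs = {"mask": [1, 1, 1], "penalty": [0.7]}

for key, value in gen_kwargs.items():
    if not isinstance(value, torch.Tensor):
        value = torch.tensor(value)
    if torch.is_floating_point(value):  # cast floats only; ints keep their dtype
        value = value.to(model_dtype)
    gen_kwargs[key] = value.to(device)

print(gen_kwargs["mask"].dtype)     # torch.int64 (unchanged)
print(gen_kwargs["penalty"].dtype)  # torch.bfloat16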

View File

@@ -15,8 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass, field
-from typing import Literal, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Literal, Optional
 
 
 @dataclass
@@ -161,3 +161,6 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)

View File

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass, field
-from typing import List, Literal, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, List, Literal, Optional
 
 
 @dataclass
@@ -318,7 +318,7 @@ class SwanLabArguments:
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
     )
-    swanlab_experiment_name: str = field(
+    swanlab_run_name: str = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
@@ -440,3 +440,8 @@ class FinetuningArguments(
         if self.pissa_init:
             raise ValueError("`pissa_init` is only valid for LoRA training.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
+        return args
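
The new to_dict() redacts any field whose name ends with "api_key" before the arguments are reported; the ModelArguments hunk below applies the same pattern to fields ending with "token". A self-contained sketch of the redaction comprehension (the dataclass here is hypothetical, only the comprehension mirrors the diff):

from dataclasses import asdict, dataclass


@dataclass
class Args:
    swanlab_api_key: str = "secret-123"
    learning_rate: float = 1e-4

    def to_dict(self):
        args = asdict(self)
        # mask secrets with a placeholder so they never reach the logger
        return {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}


print(Args().to_dict())
# {'swanlab_api_key': '<SWANLAB_API_KEY>', 'learning_rate': 0.0001}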

View File

@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import json
-from dataclasses import dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 
 import torch
@@ -344,3 +344,8 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
             setattr(result, name, value)
         return result
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
+        return args

View File

@@ -42,10 +42,13 @@ if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import save_file
 
+
 if TYPE_CHECKING:
     from transformers import TrainerControl, TrainerState, TrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
 
 logger = logging.get_logger(__name__)
@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):
     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a checkpoint save.
-        """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):
     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
                 remaining_time=self.remaining_time,
             )
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
+
+
+class ReporterCallback(TrainerCallback):
+    r"""
+    A callback for reporting training status to an external logger.
+    """
+
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.model_args = model_args
+        self.data_args = data_args
+        self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
+        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
+
+    @override
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if not state.is_world_process_zero:
+            return
+
+        if "wandb" in args.report_to:
+            import wandb
+
+            wandb.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
+
+        if self.finetuning_args.use_swanlab:
+            import swanlab
+
+            swanlab.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
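
For orientation, a minimal standalone callback in the same shape (class and parameter names below are illustrative, not the repository's): transformers calls on_train_begin on every registered callback once training starts, and the is_world_process_zero guard restricts the config push to a single process in distributed runs.

from transformers import TrainerCallback


class ConfigReporter(TrainerCallback):
    """Push one config dict to wandb, from the main process only."""

    def __init__(self, extra_config: dict) -> None:
        self.extra_config = extra_config

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:  # skip non-main ranks
            return

        if "wandb" in (args.report_to or []):
            import wandb  # deferred import: only needed when reporting is enabled

            wandb.config.update(self.extra_config, allow_val_change=True)

Registering it is a one-liner, e.g. trainer.add_callback(ConfigReporter({"seed": 42})).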

View File

@@ -30,8 +30,8 @@ from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -97,18 +97,12 @@ class CustomDPOTrainer(DPOTrainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.callback_handler.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
View File

@@ -30,7 +30,7 @@ from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -101,9 +101,6 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

View File

@@ -40,7 +40,7 @@ from typing_extensions import override
 
 from ...extras import logging
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -186,9 +186,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.

View File

@@ -20,8 +20,8 @@ from transformers import Trainer
 from typing_extensions import override
 
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -47,18 +47,12 @@ class CustomTrainer(Trainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

View File

@@ -26,8 +26,8 @@ from typing_extensions import override
 
 from ...extras import logging
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
-from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -59,18 +59,12 @@ class PairwiseTrainer(Trainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

View File

@@ -28,8 +28,8 @@ from typing_extensions import override
 
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -62,18 +62,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

View File

@@ -472,9 +472,8 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     swanlab_callback = SwanLabCallback(
         project=finetuning_args.swanlab_project,
         workspace=finetuning_args.swanlab_workspace,
-        experiment_name=finetuning_args.swanlab_experiment_name,
+        experiment_name=finetuning_args.swanlab_run_name,
         mode=finetuning_args.swanlab_mode,
-        config={"Framework": "🦙LLaMA Factory"},
+        config={"Framework": "🦙LlamaFactory"},
     )
-
-    return swanlab_callback
+    return swanlab_callback

View File

@@ -24,13 +24,14 @@ from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
-from .callbacks import LogCallback
+from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
 from .kto import run_kto
 from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
+from .trainer_utils import get_swanlab_callback
 
 
 if TYPE_CHECKING:
@@ -44,6 +45,14 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"]
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 
+    if finetuning_args.pissa_convert:
+        callbacks.append(PissaConvertCallback())
+
+    if finetuning_args.use_swanlab:
+        callbacks.append(get_swanlab_callback(finetuning_args))
+
+    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
+
     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
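
Net effect: callback wiring moves out of the individual trainers (see the hunks above) into this single spot in run_exp. A sketch of the resulting order with the rationale spelled out; this restates the hunk with comments, it is not extra repository code:

callbacks = [LogCallback()]                                # progress logging, always on
if finetuning_args.pissa_convert:
    callbacks.append(PissaConvertCallback())               # previously added per-trainer
if finetuning_args.use_swanlab:
    callbacks.append(get_swanlab_callback(finetuning_args))
# Appended last ("add to last"), presumably so the wandb/swanlab runs created by
# the Trainer's own integrations or by the callback above already exist when
# ReporterCallback pushes the argument dicts into their configs.
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))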

View File

@@ -273,21 +273,23 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Accordion(open=False) as swanlab_tab:
             with gr.Row():
                 use_swanlab = gr.Checkbox()
-                swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True)
-                swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True)
-                swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True)
-                swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True)
-                swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True)
+                swanlab_project = gr.Textbox(value="llamafactory")
+                swanlab_run_name = gr.Textbox()
+                swanlab_workspace = gr.Textbox()
+                swanlab_api_key = gr.Textbox()
+                swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
 
-        input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode})
+        input_elems.update(
+            {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
+        )
         elem_dict.update(
             dict(
                 swanlab_tab=swanlab_tab,
                 use_swanlab=use_swanlab,
-                swanlab_api_key=swanlab_api_key,
                 swanlab_project=swanlab_project,
+                swanlab_run_name=swanlab_run_name,
                 swanlab_workspace=swanlab_workspace,
-                swanlab_experiment_name=swanlab_experiment_name,
+                swanlab_api_key=swanlab_api_key,
                 swanlab_mode=swanlab_mode,
             )
         )

View File

@@ -1385,86 +1385,85 @@ LOCALES = {
             "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
         },
     },
-    "swanlab_api_key": {
-        "en": {
-            "label": "API Key(optional)",
-            "info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.",
-        },
-        "ru": {
-            "label": "API ключ(Необязательный)",
-            "info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.",
-        },
-        "zh": {
-            "label": "API密钥(选填)",
-            "info": "用于在编程环境登录SwanLab，已登录则无需填写。",
-        },
-        "ko": {
-            "label": "API 키(선택 사항)",
-            "info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.",
-        },
-    },
     "swanlab_project": {
         "en": {
-            "label": "Project(optional)",
+            "label": "SwanLab project",
         },
         "ru": {
-            "label": "Проект(Необязательный)",
+            "label": "SwanLab Проект",
         },
         "zh": {
-            "label": "项目(选填)",
+            "label": "SwanLab 项目名",
         },
         "ko": {
-            "label": "프로젝트(선택 사항)",
+            "label": "SwanLab 프로젝트",
        },
     },
+    "swanlab_run_name": {
+        "en": {
+            "label": "SwanLab experiment name (optional)",
+        },
+        "ru": {
+            "label": "SwanLab Имя эксперимента (опционально)",
+        },
+        "zh": {
+            "label": "SwanLab 实验名(非必填)",
+        },
+        "ko": {
+            "label": "SwanLab 실험 이름 (선택 사항)",
+        },
+    },
     "swanlab_workspace": {
         "en": {
-            "label": "Workspace(optional)",
-            "info": "Workspace for SwanLab. If not filled, it defaults to the personal workspace.",
+            "label": "SwanLab workspace (optional)",
+            "info": "Workspace for SwanLab. Defaults to the personal workspace.",
         },
         "ru": {
-            "label": "Рабочая область(Необязательный)",
+            "label": "SwanLab Рабочая область (опционально)",
             "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
         },
         "zh": {
-            "label": "Workspace(选填)",
-            "info": "SwanLab组织的工作区,如不填写则默认在个人工作区下",
+            "label": "SwanLab 工作区(非必填)",
+            "info": "SwanLab 的工作区,默认在个人工作区下",
         },
         "ko": {
-            "label": "작업 영역(선택 사항)",
+            "label": "SwanLab 작업 영역 (선택 사항)",
             "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
        },
     },
-    "swanlab_experiment_name": {
+    "swanlab_api_key": {
         "en": {
-            "label": "Experiment name (optional)",
+            "label": "SwanLab API key (optional)",
+            "info": "API key for SwanLab.",
         },
         "ru": {
-            "label": "Имя эксперимента(Необязательный)",
+            "label": "SwanLab API ключ (опционально)",
+            "info": "API ключ для SwanLab.",
         },
         "zh": {
-            "label": "实验名(选填)",
+            "label": "SwanLab API密钥（非必填）",
+            "info": "用于在编程环境登录 SwanLab，已登录则无需填写。",
        },
         "ko": {
-            "label": "실험 이름(선택 사항)",
+            "label": "SwanLab API 키 (선택 사항)",
+            "info": "SwanLab의 API 키.",
        },
     },
     "swanlab_mode": {
         "en": {
-            "label": "Mode",
-            "info": "Cloud or offline version.",
+            "label": "SwanLab mode",
+            "info": "Cloud or offline version.",
         },
         "ru": {
-            "label": "Режим",
+            "label": "SwanLab Режим",
             "info": "Версия в облаке или локальная версия.",
         },
         "zh": {
-            "label": "模式",
-            "info": "云端版或离线版",
+            "label": "SwanLab 模式",
+            "info": "使用云端版或离线版 SwanLab。",
        },
         "ko": {
-            "label": "모드",
+            "label": "SwanLab 모드",
             "info": "클라우드 버전 또는 오프라인 버전.",
        },
     },

View File

@@ -231,12 +231,11 @@ class Runner:
 
         # swanlab config
         if get("train.use_swanlab"):
-            args["swanlab_api_key"] = get("train.swanlab_api_key")
             args["swanlab_project"] = get("train.swanlab_project")
+            args["swanlab_run_name"] = get("train.swanlab_run_name")
             args["swanlab_workspace"] = get("train.swanlab_workspace")
-            args["swanlab_experiment_name"] = get("train.swanlab_experiment_name")
+            args["swanlab_api_key"] = get("train.swanlab_api_key")
             args["swanlab_mode"] = get("train.swanlab_mode")
 
         # eval config
         if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
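
With the WebUI wired to the renamed field, a run with SwanLab enabled would forward arguments shaped roughly like this (a hypothetical example; values are placeholders):

args = {
    "swanlab_project": "llamafactory",
    "swanlab_run_name": "demo-run",  # formerly swanlab_experiment_name
    "swanlab_workspace": "",
    "swanlab_api_key": "",
    "swanlab_mode": "cloud",
}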