support report custom args
Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea
This commit is contained in:
@@ -24,13 +24,14 @@ from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .dpo import run_dpo
|
||||
from .kto import run_kto
|
||||
from .ppo import run_ppo
|
||||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
from .trainer_utils import get_swanlab_callback
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -44,6 +45,14 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
||||
callbacks.append(LogCallback())
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
|
||||
if finetuning_args.pissa_convert:
|
||||
callbacks.append(PissaConvertCallback())
|
||||
|
||||
if finetuning_args.use_swanlab:
|
||||
callbacks.append(get_swanlab_callback(finetuning_args))
|
||||
|
||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "sft":
|
||||
|
||||
Reference in New Issue
Block a user