Former-commit-id: 8d035a849c4a441d457791aab073861adf69a09f
This commit is contained in:
BUAADreamer
2024-04-25 21:08:32 +08:00
parent 9b210cf4b3
commit dbd905438b
8 changed files with 80 additions and 283 deletions

View File

@@ -14,7 +14,6 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .sftmm import run_sft_mm
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -30,8 +29,6 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "sft_mm":
run_sft_mm(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "ppo":