add template matching and stage selection in webui

Former-commit-id: d6283e7f041f08f76d18350cb5f6a6c58ca80e92
Author: codemayq
Date:   2023-08-14 20:42:59 +08:00
Parent: 688e8601ab
Commit: 9585699918
6 changed files with 77 additions and 14 deletions

@@ -8,11 +8,11 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from typing import Any, Dict, Generator, List, Tuple
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE
+from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_model_path, get_save_dir
+from llmtuner.webui.common import get_model_path, get_save_dir, get_template
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
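The get_template helper imported here is not shown in this diff. A minimal sketch of what it plausibly does, assuming DEFAULT_TEMPLATE is a dict mapping model-name prefixes to template names, "default" is the UI placeholder value, and DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL is the fallback template name (all three are assumptions; only the import is from the commit):

# Sketch only -- not the commit's actual implementation of llmtuner.webui.common.get_template.
from llmtuner.extras.constants import DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL

def get_template(template: str, model_name: str) -> str:
    if template != "default":
        return template  # the user picked a template explicitly; keep it
    prefix = model_name.split("-")[0]  # e.g. "LLaMA2-7B" -> "LLaMA2" (assumed naming convention)
    return DEFAULT_TEMPLATE.get(prefix, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL)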
@@ -70,6 +70,7 @@ class Runner:
         quantization_bit: str,
         template: str,
         source_prefix: str,
+        stage: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -91,7 +92,6 @@ class Runner:
         lora_dropout: float,
         lora_target: str,
         resume_lora_training: bool,
-        rlhf_method: str,
         dpo_beta: float,
         reward_model: str,
         output_dir: str
@@ -113,7 +113,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
-            template=template,
+            template=get_template(template, model_name),
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
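With this change, the template passed to the trainer is resolved against the model name instead of being forwarded verbatim. A hypothetical call matching the sketch above (model names and results are illustrative, not from the commit):

# Illustrative only: an unspecified template falls back to the model's default;
# an explicit choice passes through unchanged.
get_template("default", "LLaMA2-7B")  # -> e.g. "llama2", if DEFAULT_TEMPLATE maps "LLaMA2" to it
get_template("vicuna", "LLaMA2-7B")   # -> "vicuna"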
@@ -138,16 +138,18 @@ class Runner:
         )
         args[compute_type] = True

-        if rlhf_method == "Reward Modeling":
+        if stage == "Pretraining":
+            args["stage"] = "pt"
+        if stage == "Reward Modeling":
             args["stage"] = "rm"
             args["resume_lora_training"] = False
-        elif rlhf_method == "PPO":
+        elif stage == "PPO":
             args["stage"] = "ppo"
             args["resume_lora_training"] = False
             args["reward_model"] = reward_model
             args["padding_side"] = "left"
             val_size = 0
-        elif rlhf_method == "DPO":
+        elif stage == "DPO":
             args["stage"] = "dpo"
             args["resume_lora_training"] = False
             args["dpo_beta"] = dpo_beta
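The ladder above maps the web UI's stage labels to the trainer's internal stage codes (pt, rm, ppo, dpo). An equivalent table-driven form, as a sketch (the "sft" fallback code is an assumption; this hunk only shows the other four):

# Sketch of the mapping encoded by the branches above, not code from this commit.
def stage_args(stage: str) -> dict:
    codes = {"Pretraining": "pt", "Reward Modeling": "rm", "PPO": "ppo", "DPO": "dpo"}
    extra = {"stage": codes.get(stage, "sft")}  # "sft" default is assumed
    if stage in ("Reward Modeling", "PPO", "DPO"):
        extra["resume_lora_training"] = False  # these stages disable LoRA resumption, per the diff
    return extra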
@@ -195,7 +197,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
-            template=template,
+            template=get_template(template, model_name),
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),