add template match and stage in webui
Former-commit-id: d6283e7f041f08f76d18350cb5f6a6c58ca80e92
This commit is contained in:
@@ -8,11 +8,11 @@ from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, get_template
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
@@ -70,6 +70,7 @@ class Runner:
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
stage: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
@@ -91,7 +92,6 @@ class Runner:
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
resume_lora_training: bool,
|
||||
rlhf_method: str,
|
||||
dpo_beta: float,
|
||||
reward_model: str,
|
||||
output_dir: str
|
||||
@@ -113,7 +113,7 @@ class Runner:
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
template=get_template(template, model_name),
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
@@ -138,16 +138,18 @@ class Runner:
|
||||
)
|
||||
args[compute_type] = True
|
||||
|
||||
if rlhf_method == "Reward Modeling":
|
||||
if stage == "Pretraining":
|
||||
args["stage"] = "pt"
|
||||
if stage == "Reward Modeling":
|
||||
args["stage"] = "rm"
|
||||
args["resume_lora_training"] = False
|
||||
elif rlhf_method == "PPO":
|
||||
elif stage == "PPO":
|
||||
args["stage"] = "ppo"
|
||||
args["resume_lora_training"] = False
|
||||
args["reward_model"] = reward_model
|
||||
args["padding_side"] = "left"
|
||||
val_size = 0
|
||||
elif rlhf_method == "DPO":
|
||||
elif stage == "DPO":
|
||||
args["stage"] = "dpo"
|
||||
args["resume_lora_training"] = False
|
||||
args["dpo_beta"] = dpo_beta
|
||||
@@ -195,7 +197,7 @@ class Runner:
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
template=get_template(template, model_name),
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
|
||||
Reference in New Issue
Block a user