Former-commit-id: 627d1c91e675f1d9ebf47bad123cbbf29821da4d
This commit is contained in:
hiyouga
2024-03-09 02:01:26 +08:00
parent 2f095e2017
commit 43b2ede0f8
7 changed files with 28 additions and 20 deletions

View File

@@ -52,6 +52,8 @@ class Runner:
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
stage = TRAINING_STAGES[get("train.training_stage")]
reward_model = get("train.reward_model")
if self.running:
return ALERTS["err_conflict"][lang]
@@ -65,6 +67,9 @@ class Runner:
if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang]
if stage == "ppo" and not reward_model:
return ALERTS["err_no_reward_model"][lang]
if not from_preview and self.demo_mode:
return ALERTS["err_demo"][lang]
@@ -163,8 +168,11 @@ class Runner:
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
args["reward_model"] = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("train.reward_model")
]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"