Former-commit-id: 627d1c91e675f1d9ebf47bad123cbbf29821da4d
This commit is contained in:
@@ -52,6 +52,8 @@ class Runner:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
stage = TRAINING_STAGES[get("train.training_stage")]
|
||||
reward_model = get("train.reward_model")
|
||||
|
||||
if self.running:
|
||||
return ALERTS["err_conflict"][lang]
|
||||
@@ -65,6 +67,9 @@ class Runner:
|
||||
if len(dataset) == 0:
|
||||
return ALERTS["err_no_dataset"][lang]
|
||||
|
||||
if stage == "ppo" and not reward_model:
|
||||
return ALERTS["err_no_reward_model"][lang]
|
||||
|
||||
if not from_preview and self.demo_mode:
|
||||
return ALERTS["err_demo"][lang]
|
||||
|
||||
@@ -163,8 +168,11 @@ class Runner:
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
|
||||
args["reward_model"] = ",".join(
|
||||
[
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("train.reward_model")
|
||||
]
|
||||
)
|
||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user