update ppo and demo in webui

Former-commit-id: de7571704c82121db13e3fc907379d2453100191
This commit is contained in:
hiyouga
2023-11-16 14:55:26 +08:00
parent f9d4e37b3c
commit df83def566
4 changed files with 32 additions and 16 deletions

View File

@@ -31,7 +31,6 @@ class Runner:
self.thread: "Thread" = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.monitor_inputs: Dict[str, str] = None
""" State """
self.aborted = False
self.running = False
@@ -75,6 +74,7 @@ class Runner:
def _finalize(self, lang: str, finish_info: str) -> str:
self.thread = None
self.running_data = None
self.running = False
torch_gc()
if self.aborted:
@@ -87,9 +87,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
else:
checkpoint_dir = None
@@ -139,7 +139,10 @@ class Runner:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
)
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
@@ -157,9 +160,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
@@ -216,7 +219,6 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.do_train, self.running_data = do_train, data
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
yield from self.monitor()
@@ -235,7 +237,10 @@ class Runner:
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
self.running = True
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
output_dir = self.running_data[self.manager.get_elem_by_name(
"{}.output_dir".format("train" if self.do_train else "eval")
)]
while self.thread.is_alive():
time.sleep(2)
if self.aborted: