Update PPO and demo in the web UI
Former-commit-id: de7571704c82121db13e3fc907379d2453100191
This commit is contained in:
@@ -31,7 +31,6 @@ class Runner:
|
||||
self.thread: "Thread" = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
self.monitor_inputs: Dict[str, str] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
@@ -75,6 +74,7 @@ class Runner:
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
self.thread = None
|
||||
self.running_data = None
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
@@ -87,9 +87,9 @@ class Runner:
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
checkpoint_dir = ",".join([get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), ckpt
|
||||
) for ckpt in get("top.checkpoints")])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
@@ -139,7 +139,10 @@ class Runner:
|
||||
args["upcast_layernorm"] = True
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
|
||||
args["reward_model"] = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
|
||||
)
|
||||
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
|
||||
|
||||
if args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
@@ -157,9 +160,9 @@ class Runner:
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
checkpoint_dir = ",".join([get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), ckpt
|
||||
) for ckpt in get("top.checkpoints")])
|
||||
output_dir = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
||||
)
|
||||
@@ -216,7 +219,6 @@ class Runner:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
@@ -235,7 +237,10 @@ class Runner:
|
||||
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
self.running = True
|
||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||
lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
|
||||
output_dir = self.running_data[self.manager.get_elem_by_name(
|
||||
"{}.output_dir".format("train" if self.do_train else "eval")
|
||||
)]
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
|
||||
Reference in New Issue
Block a user