modify code structure

Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
This commit is contained in:
hiyouga
2023-08-02 23:17:36 +08:00
parent 8bd1da7144
commit 28a51b622b
25 changed files with 188 additions and 145 deletions

View File

@@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
@@ -105,6 +105,7 @@ class Runner:
checkpoint_dir = None
args = dict(
stage="sft",
model_name_or_path=model_name_or_path,
do_train=True,
overwrite_cache=True,
@@ -141,16 +142,8 @@ class Runner:
args["eval_steps"] = save_steps
args["load_best_model_at_end"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
run_kwargs = dict(args=args, callbacks=[trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
@@ -158,7 +151,7 @@ class Runner:
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield format_info(logger_handler.log, trainer_callback)
yield self.finalize(lang)
@@ -194,6 +187,7 @@ class Runner:
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
args = dict(
stage="sft",
model_name_or_path=model_name_or_path,
do_eval=True,
overwrite_cache=True,
@@ -216,16 +210,8 @@ class Runner:
args.pop("do_eval", None)
args["do_predict"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
run_kwargs = dict(args=args, callbacks=[trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
@@ -233,6 +219,6 @@ class Runner:
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield format_info(logger_handler.log, trainer_callback)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))