modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
This commit is contained in:
@@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import get_train_args, run_sft
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import format_info, get_eval_results
|
||||
@@ -105,6 +105,7 @@ class Runner:
|
||||
checkpoint_dir = None
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_train=True,
|
||||
overwrite_cache=True,
|
||||
@@ -141,16 +142,8 @@ class Runner:
|
||||
args["eval_steps"] = save_steps
|
||||
args["load_best_model_at_end"] = True
|
||||
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
|
||||
run_args = dict(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[trainer_callback]
|
||||
)
|
||||
thread = threading.Thread(target=run_sft, kwargs=run_args)
|
||||
run_kwargs = dict(args=args, callbacks=[trainer_callback])
|
||||
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
@@ -158,7 +151,7 @@ class Runner:
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback.tracker)
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
|
||||
yield self.finalize(lang)
|
||||
|
||||
@@ -194,6 +187,7 @@ class Runner:
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_eval=True,
|
||||
overwrite_cache=True,
|
||||
@@ -216,16 +210,8 @@ class Runner:
|
||||
args.pop("do_eval", None)
|
||||
args["do_predict"] = True
|
||||
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
|
||||
run_args = dict(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[trainer_callback]
|
||||
)
|
||||
thread = threading.Thread(target=run_sft, kwargs=run_args)
|
||||
run_kwargs = dict(args=args, callbacks=[trainer_callback])
|
||||
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
@@ -233,6 +219,6 @@ class Runner:
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback.tracker)
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
|
||||
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
|
||||
|
||||
Reference in New Issue
Block a user