modify code structure

Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
Author: hiyouga
Date: 2023-08-02 23:17:36 +08:00
Parent: 8bd1da7144
Commit: 28a51b622b
25 changed files with 188 additions and 145 deletions

src/llmtuner/webui/__init__.py (new file)

@@ -0,0 +1,4 @@
+from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.interface import create_ui
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.components import create_chat_box
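With these re-exports in place, the web UI can be started from the package root. A minimal usage sketch, assuming nothing beyond stock Gradio (queue() and launch() are standard gr.Blocks methods, not part of this commit):

from llmtuner.webui import create_ui

if __name__ == "__main__":
    demo = create_ui()   # returns a gr.Blocks app (see interface.py below)
    demo.queue()         # queuing is needed for the generator-based callbacks
    demo.launch()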

src/llmtuner/webui/chat.py

@@ -1,22 +1,21 @@
 import os
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.tuner import get_infer_args
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS


 class WebChatModel(ChatModel):

-    def __init__(self, *args):
+    def __init__(self, args: Optional[Dict[str, Any]]) -> None:
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if len(args) != 0:
-            super().__init__(*args)
+        if args is not None:
+            super().__init__(args)

     def load_model(
         self,
@@ -57,7 +56,7 @@ class WebChatModel(ChatModel):
             template=template,
             source_prefix=source_prefix
         )
-        super().__init__(*get_infer_args(args))
+        super().__init__(args)

         yield ALERTS["info_loaded"][lang]
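The constructor change is the crux of this file: WebChatModel now hands the raw argument dict straight to its parent instead of unpacking get_infer_args at the call site. That implies ChatModel itself now accepts a plain dict and resolves it internally. A hedged sketch of that assumed parent-side convention (the real ChatModel lives in llmtuner.chat.stream_chat and is not shown in this diff):

from typing import Any, Dict, Optional

from llmtuner.tuner import get_infer_args

class ChatModel:
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        # Assumption: the parent now calls get_infer_args itself, so every
        # caller (CLI and web UI alike) passes one plain dict.
        model_args, _, finetuning_args, generating_args = get_infer_args(args)
        ...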

src/llmtuner/webui/components/__init__.py

@@ -3,3 +3,4 @@ from llmtuner.webui.components.sft import create_sft_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab
+from llmtuner.webui.components.chatbot import create_chat_box

src/llmtuner/webui/components/export.py

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 import gradio as gr

-from llmtuner.webui.utils import export_model
+from llmtuner.webui.utils import save_model

 if TYPE_CHECKING:
     from gradio.components import Component
@@ -16,7 +16,7 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     info_box = gr.Textbox(show_label=False, interactive=False)

     export_btn.click(
-        export_model,
+        save_model,
         [
             top_elems["lang"],
             top_elems["model_name"],

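The rename matters because the webui helper would otherwise shadow llmtuner.tuner.export_model, which utils.py (below) now imports. The handler itself is a generator, and Gradio streams each yielded string into the bound output Textbox, which is how info_box shows live status. A toy illustration of that pattern with hypothetical names (fake_save_model is not part of this commit):

import gradio as gr

def fake_save_model(name: str):
    # Each yield is pushed to the output component as it happens.
    yield "Exporting {} ...".format(name)
    yield "Export finished."

with gr.Blocks() as demo:
    name_box = gr.Textbox(label="Model name")
    info_box = gr.Textbox(show_label=False, interactive=False)
    export_btn = gr.Button("Export")
    export_btn.click(fake_save_model, [name_box], [info_box])

demo.queue()  # generator callbacks require the queue to be enabled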
src/llmtuner/webui/interface.py

@@ -47,6 +47,7 @@ def create_ui() -> gr.Blocks:
             manager.gen_label,
             [top_elems["lang"]],
             [elem for elems in elem_list for elem in elems.values()],
+            queue=False
         )

     return demo
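The only change here is queue=False on the relabeling event. With queuing enabled for the long-running training callbacks, a cheap UI-only event such as switching the display language can opt out of the queue and round-trip immediately. A hedged, self-contained illustration (this gen_label is a stand-in, not the Manager method):

import gradio as gr

def gen_label(lang: str):
    # Only updates a component label; no heavy work, so skip the queue.
    return gr.update(label="Model" if lang == "en" else "模型")

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"], value="en")
    model_name = gr.Textbox(label="Model")
    lang.change(gen_label, [lang], [model_name], queue=False)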

src/llmtuner/webui/runner.py

@@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import get_train_args, run_sft
+from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import format_info, get_eval_results
@@ -105,6 +105,7 @@ class Runner:
         checkpoint_dir = None

         args = dict(
+            stage="sft",
             model_name_or_path=model_name_or_path,
             do_train=True,
             overwrite_cache=True,
@@ -141,16 +142,8 @@
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()

         while thread.is_alive():
@@ -158,7 +151,7 @@
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)

         yield self.finalize(lang)
@@ -194,6 +187,7 @@ class Runner:
         output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

         args = dict(
+            stage="sft",
             model_name_or_path=model_name_or_path,
             do_eval=True,
             overwrite_cache=True,
@@ -216,16 +210,8 @@
             args.pop("do_eval", None)
             args["do_predict"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()

         while thread.is_alive():
@@ -233,6 +219,6 @@
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)

         yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
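Both the training and evaluation paths now build a single args dict (with an explicit "stage" key) and hand it to run_exp on a background thread, instead of parsing arguments in the web layer and targeting run_sft directly. The real run_exp lives in llmtuner.tuner and is not shown in this commit; a plausible sketch of the dispatch it implies, using only names visible in this diff:

from typing import Any, Dict, List, Optional

def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List] = None):
    # get_train_args returns a 5-tuple, as the removed call sites above show.
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
    if (args or {}).get("stage") == "sft":  # assumption: dispatch keys off "stage"
        run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
    # other stages (pre-training, reward modeling, PPO, ...) would branch here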

src/llmtuner/webui/utils.py

@@ -3,20 +3,23 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner import export_model
 from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.extras.callbacks import LogCallback
+

-def format_info(log: str, tracker: dict) -> str:
+def format_info(log: str, callback: "LogCallback") -> str:
     info = log
-    if "current_steps" in tracker:
+    if callback.max_steps:
         info += "Running **{:d}/{:d}**: {} < {}\n".format(
-            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
+            callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
         )
     return info
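format_info now reads progress directly off the LogCallback object rather than a loose tracker dict, so the field names are checked in one place. The real class is in llmtuner.extras.callbacks; only the four attribute names below are taken from this diff, the rest is an assumed shape:

class LogCallback:
    def __init__(self) -> None:
        self.cur_steps = 0        # updated by the trainer as steps complete
        self.max_steps = 0        # 0 until training starts, hence the truthiness check above
        self.elapsed_time = ""    # preformatted time strings (assumed)
        self.remaining_time = ""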
@@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
     return fig


-def export_model(
+def save_model(
     lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
@@ -114,12 +117,10 @@ def export_model(
     args = dict(
         model_name_or_path=model_name_or_path,
         checkpoint_dir=checkpoint_dir,
-        finetuning_type=finetuning_type
+        finetuning_type=finetuning_type,
+        output_dir=save_dir
     )

     yield ALERTS["info_exporting"][lang]
-    model_args, _, finetuning_args, _ = get_infer_args(args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
-    tokenizer.save_pretrained(save_dir)
+    export_model(args, max_shard_size="{}GB".format(max_shard_size))
     yield ALERTS["info_exported"][lang]
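The four removed lines show exactly what moved behind llmtuner.tuner.export_model: argument parsing, model loading, and the sharded save. A hedged reconstruction of that helper, inferred from this call site only (the shipped implementation may parse args and resolve output_dir differently):

from typing import Any, Dict, Optional

def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: str = "10GB"):
    # Signature and default shard size are assumptions; the body mirrors the
    # code removed from save_model above.
    model_args, _, finetuning_args, _ = get_infer_args(args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
    model.save_pretrained(args["output_dir"], max_shard_size=max_shard_size)
    tokenizer.save_pretrained(args["output_dir"])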