modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
--- /dev/null
+++ b/src/llmtuner/webui/__init__.py
@@ -0,0 +1,4 @@
+from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.interface import create_ui
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.components import create_chat_box
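Note: with these re-exports in place, the web UI can be started from the package root. A minimal usage sketch, assuming only what the diff shows (create_ui() returns a gr.Blocks per the interface.py hunk below; launch() is the standard Gradio entry point):

    from llmtuner.webui import create_ui

    def main() -> None:
        demo = create_ui()   # gr.Blocks, see create_ui() -> gr.Blocks below
        demo.launch()        # default Gradio host/port

    if __name__ == "__main__":
        main()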
--- a/src/llmtuner/webui/chat.py
+++ b/src/llmtuner/webui/chat.py
@@ -1,22 +1,21 @@
 import os
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.tuner import get_infer_args
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 
 
 class WebChatModel(ChatModel):
 
-    def __init__(self, *args):
+    def __init__(self, args: Optional[Dict[str, Any]]) -> None:
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if len(args) != 0:
-            super().__init__(*args)
+        if args is not None:
+            super().__init__(args)
 
     def load_model(
         self,
@@ -57,7 +56,7 @@ class WebChatModel(ChatModel):
             template=template,
             source_prefix=source_prefix
         )
-        super().__init__(*get_infer_args(args))
+        super().__init__(args)
 
         yield ALERTS["info_loaded"][lang]
 
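Note: WebChatModel now takes a single optional dict instead of *args, so the UI can construct it with no model loaded and hand over the collected fields later. A sketch under the assumption that ChatModel.__init__ accepts such a dict and parses it internally (the key values below are hypothetical illustrations):

    # No model in memory yet; load_model() fills it in later.
    chat_model = WebChatModel(None)
    assert chat_model.model is None and chat_model.tokenizer is None

    # Passing a dict triggers super().__init__(args) immediately.
    args = dict(
        model_name_or_path="path/to/model",  # hypothetical UI value
        template="default",                  # hypothetical UI value
    )
    chat_model = WebChatModel(args)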
--- a/src/llmtuner/webui/components/__init__.py
+++ b/src/llmtuner/webui/components/__init__.py
@@ -3,3 +3,4 @@ from llmtuner.webui.components.sft import create_sft_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab
+from llmtuner.webui.components.chatbot import create_chat_box
--- a/src/llmtuner/webui/components/export.py
+++ b/src/llmtuner/webui/components/export.py
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 import gradio as gr
 
-from llmtuner.webui.utils import export_model
+from llmtuner.webui.utils import save_model
 
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -16,7 +16,7 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     info_box = gr.Textbox(show_label=False, interactive=False)
 
     export_btn.click(
-        export_model,
+        save_model,
         [
             top_elems["lang"],
             top_elems["model_name"],
--- a/src/llmtuner/webui/interface.py
+++ b/src/llmtuner/webui/interface.py
@@ -47,6 +47,7 @@ def create_ui() -> gr.Blocks:
         manager.gen_label,
         [top_elems["lang"]],
+        [elem for elems in elem_list for elem in elems.values()],
         queue=False
     )
 
     return demo
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import get_train_args, run_sft
+from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import format_info, get_eval_results
@@ -105,6 +105,7 @@ class Runner:
             checkpoint_dir = None
 
         args = dict(
+            stage="sft",
            model_name_or_path=model_name_or_path,
            do_train=True,
            overwrite_cache=True,
@@ -141,16 +142,8 @@
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True
 
-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()
 
         while thread.is_alive():
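Note: training and evaluation now share one launch pattern: build a flat args dict (including the new stage key) and hand it to run_exp in a background thread. A minimal sketch, assuming run_exp parses the dict itself (e.g. via get_train_args) and dispatches on stage, which is what moving get_train_args out of the Runner suggests:

    import threading

    from llmtuner.tuner import run_exp

    def launch(args: dict, trainer_callback) -> threading.Thread:
        run_kwargs = dict(args=args, callbacks=[trainer_callback])
        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
        thread.start()
        return thread  # caller polls thread.is_alive() and yields log updates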
@@ -158,7 +151,7 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)
 
         yield self.finalize(lang)
 
@@ -194,6 +187,7 @@ class Runner:
         output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
 
         args = dict(
+            stage="sft",
            model_name_or_path=model_name_or_path,
            do_eval=True,
            overwrite_cache=True,
@@ -216,16 +210,8 @@
             args.pop("do_eval", None)
             args["do_predict"] = True
 
-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()
 
         while thread.is_alive():
@@ -233,6 +219,6 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)
 
         yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -3,20 +3,23 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
 from datetime import datetime
 
 from llmtuner.extras.ploting import smooth
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner import export_model
 from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS
 
+if TYPE_CHECKING:
+    from llmtuner.extras.callbacks import LogCallback
+
 
-def format_info(log: str, tracker: dict) -> str:
+def format_info(log: str, callback: "LogCallback") -> str:
     info = log
-    if "current_steps" in tracker:
+    if callback.max_steps:
         info += "Running **{:d}/{:d}**: {} < {}\n".format(
-            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
+            callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
         )
     return info
 
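Note: format_info now reads progress attributes off the LogCallback instead of a tracker dict. A quick check with a hypothetical stub exposing just the four attributes the function touches (the real class is llmtuner.extras.callbacks.LogCallback):

    class StubCallback:  # hypothetical stand-in, not the real LogCallback
        cur_steps = 42
        max_steps = 100
        elapsed_time = "00:03:21"
        remaining_time = "00:04:38"

    print(format_info("loss: 0.87\n", StubCallback()))
    # loss: 0.87
    # Running **42/100**: 00:03:21 < 00:04:38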
@@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
     return fig
 
 
-def export_model(
+def save_model(
     lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
@@ -114,12 +117,10 @@ def export_model(
     args = dict(
         model_name_or_path=model_name_or_path,
         checkpoint_dir=checkpoint_dir,
-        finetuning_type=finetuning_type
+        finetuning_type=finetuning_type,
+        output_dir=save_dir
     )
 
     yield ALERTS["info_exporting"][lang]
-    model_args, _, finetuning_args, _ = get_infer_args(args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
-    tokenizer.save_pretrained(save_dir)
+    export_model(args, max_shard_size="{}GB".format(max_shard_size))
     yield ALERTS["info_exported"][lang]
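Note: save_model no longer loads the model and calls save_pretrained itself; it forwards the same args dict (now carrying output_dir) to llmtuner.tuner.export_model. An equivalent direct call, with placeholder values:

    from llmtuner.tuner import export_model

    export_model(
        dict(
            model_name_or_path="path/to/base-model",  # placeholder
            checkpoint_dir="path/to/checkpoint",      # placeholder
            finetuning_type="lora",                   # placeholder
            output_dir="path/to/export",
        ),
        max_shard_size="10GB",  # same "{}GB".format(...) convention as above
    )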