release v0.1.4

Former-commit-id: 81f84aaf2e120e39edb28ef42893939fc9a184e2
This commit is contained in:
hiyouga
2023-08-01 10:08:47 +08:00
parent 772ad4ec6b
commit 661890b8a1
13 changed files with 66 additions and 49 deletions

View File

@@ -1,4 +1,4 @@
from llmtuner.chat import ChatModel
__version__ = "0.1.3"
__version__ = "0.1.4"

View File

@@ -93,9 +93,11 @@ def load_model_and_tokenizer(
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
else:

View File

@@ -32,9 +32,9 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((
GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
))
return _parse_args(parser, args)
@@ -51,7 +51,7 @@ def parse_infer_args(
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@@ -79,6 +79,12 @@ def get_train_args(
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
assert not (training_args.max_steps == -1 and data_args.streaming), \
"Please specify `max_steps` in streaming mode."
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@@ -108,12 +114,6 @@ def get_train_args(
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
data_args.dev_ratio = 0
assert not (training_args.max_steps == -1 and data_args.streaming), \
"Please specify `max_steps` in streaming mode."
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:

View File

@@ -1,16 +1,17 @@
from typing import Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel,
chat_model: "WebChatModel",
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()

View File

@@ -1,10 +1,12 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)

View File

@@ -1,14 +1,16 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)

View File

@@ -1,11 +1,13 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.utils import export_model
if TYPE_CHECKING:
from gradio.components import Component
def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)

View File

@@ -1,13 +1,15 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()

View File

@@ -1,16 +1,18 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)

View File

@@ -1,15 +1,17 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, Component]:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():

View File

@@ -1,15 +1,17 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from typing import TYPE_CHECKING, Any, Dict, List
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time
if TYPE_CHECKING:
from gradio.components import Component
class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
def __init__(self, elem_list: List[Dict[str, "Component"]]):
self.elem_list = elem_list
def gen_refresh(self) -> Dict[str, Any]:
@@ -24,7 +26,7 @@ class Manager:
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, dict]:
def gen_label(self, lang: str) -> Dict["Component", dict]:
update_dict = {}
refresh_dict = self.gen_refresh()