12 Commits

Author SHA1 Message Date
hiyouga
35e76879f5 support dev set in web ui
Former-commit-id: fe1370561a9b027d9ebdef52733344f1e3683081
2023-07-18 20:40:49 +08:00
hiyouga
8e4ae0aaac add web demo
Former-commit-id: 25ea647e5ac36b497b8e176b123fdee39be3fd30
2023-07-18 17:21:16 +08:00
hiyouga
5ed2a97056 update baichuan template
Former-commit-id: 03520588c39986c98a0515a64993af8c2468b9d0
2023-07-18 16:43:51 +08:00
hiyouga
03eba6f041 fix template
Former-commit-id: 729053c9cea6254165ae9c8fd7809479b12f735c
2023-07-18 16:37:23 +08:00
hiyouga
ec166e736a fix #176
Former-commit-id: 2ae3445b0d28b4ed22ddbb2cfe09089ae0c23fe1
2023-07-18 16:36:24 +08:00
hiyouga
c85a6b83b3 fix webUI, fix #171 #177
Former-commit-id: 3459bb2d35162dbbef79cda05da08a56921aa276
2023-07-18 15:51:48 +08:00
hiyouga
a864a7b395 update webUI, fix #179
Former-commit-id: f9074fed5e22585679661588befcf266a79009f2
2023-07-18 15:35:17 +08:00
hiyouga
fd8c2d4aac tiny fix
Former-commit-id: bcdf5bb55651d639e9f57fd915268137156af9cd
2023-07-18 00:52:31 +08:00
hiyouga
baf2e4e825 a monkey patch for lora_target
Former-commit-id: 622f44a05b49b10571bd189ae3843683117ad77f
2023-07-18 00:31:40 +08:00
hiyouga
eac7f97337 release v0.1.0
Former-commit-id: 63c8d3a17cb18f0d8a8e37bfa147daf5bdd28ea9
2023-07-18 00:18:25 +08:00
hiyouga
c08ff734a7 fix #175
Former-commit-id: fd557ebb5e3ef2ca330b4d97731af43f4a5a5fc5
2023-07-17 18:07:17 +08:00
hiyouga
e9736b2ba0 fix saving custom code
Former-commit-id: 3f8f40bffd4f61fcc045f5f8a07420f3b46d0f7a
2023-07-16 18:04:41 +08:00
36 changed files with 1900 additions and 346 deletions

View File

@@ -10,7 +10,9 @@
## Changelog
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base`, `--padding_side right` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
@@ -125,14 +127,10 @@ cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
### LLaMA Weights Preparation (optional)
1. Download the weights of the LLaMA models.
2. Convert them to HF format using the following command.
### All-in-one Web UI
```bash
python -m transformers.models.llama.convert_llama_weights_to_hf \
--input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model
python src/train_web.py
```
### (Continually) Pre-Training
@@ -275,10 +273,28 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
### API / CLI / Web Demo
### API Demo
```bash
python src/xxx_demo.py \
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
See `http://localhost:8000/docs` for API documentation.
### CLI Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```

View File

@@ -3,14 +3,14 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.19.0
peft>=0.3.0
trl>=0.4.4
trl>=0.4.7
sentencepiece
jieba
rouge-chinese
nltk
gradio>=3.36.0
uvicorn
pydantic==1.10.7
pydantic
fastapi
sse-starlette
matplotlib

View File

@@ -1,6 +1,7 @@
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
from llmtuner.webui import create_ui
__version__ = "0.0.9"
__version__ = "0.1.0"

View File

@@ -1,3 +1,4 @@
import json
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
@@ -9,6 +10,8 @@ from llmtuner.tuner import get_infer_args
from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
ModelCard,
ModelList,
ChatMessage,
@@ -48,12 +51,12 @@ def create_app():
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != "user":
if request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
else:
prefix = None
@@ -61,7 +64,7 @@ def create_app():
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
@@ -80,20 +83,20 @@ def create_app():
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
message=ChatMessage(role=Role.ASSISTANT, content=response),
finish_reason=Finish.STOP
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
delta=DeltaMessage(role=Role.ASSISTANT),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
yield chunk.json(exclude_unset=True, ensure_ascii=False)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
for new_text in chat_model.stream_chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
@@ -106,16 +109,16 @@ def create_app():
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
yield chunk.json(exclude_unset=True, ensure_ascii=False)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
yield chunk.json(exclude_unset=True, ensure_ascii=False)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
yield "[DONE]"
return app

View File

@@ -1,6 +1,18 @@
import time
from enum import Enum
from pydantic import BaseModel, Field
from typing import List, Literal, Optional
from typing import List, Optional
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
role: Role
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
role: Optional[Role] = None
content: Optional[str] = None
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
finish_reason: Optional[Finish] = None
class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion"]
object: Optional[str] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
class ChatCompletionStreamResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion.chunk"]
object: Optional[str] = "chat.completion.chunk"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]

View File

@@ -1,3 +1,4 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
@@ -41,10 +42,10 @@ class ChatModel:
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
temperature=temperature if temperature else gen_kwargs["temperature"],
top_p=top_p if top_p else gen_kwargs["top_p"],
top_k=top_k if top_k else gen_kwargs["top_k"],
repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
))
@@ -58,6 +59,7 @@ class ChatModel:
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:
@@ -68,6 +70,7 @@ class ChatModel:
response_length = len(outputs)
return response, (prompt_length, response_length)
@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]:

View File

@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
r"""
Event called after logging the last logs.
"""
if not state.is_world_process_zero:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time

View File

@@ -5,3 +5,36 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
METHODS = ["full", "freeze", "lora"]
SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
}
DEFAULT_MODULE = { # will be deprecated
"LLaMA": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"InternLM": "q_proj,v_proj"
}

View File

@@ -2,6 +2,20 @@ import sys
import logging
class LoggerHandler(logging.Handler):
def __init__(self):
super().__init__()
self.log = ""
def emit(self, record):
if record.name == "httpx":
return
log_entry = self.format(record)
self.log += log_entry
self.log += "\n\n"
def get_logger(name: str) -> logging.Logger:
formatter = logging.Formatter(

View File

@@ -1,4 +1,5 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
@@ -10,12 +11,13 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)

View File

@@ -1,141 +1,29 @@
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class Format:
prefix: str
prompt: str
sep: str
use_history: bool
templates: Dict[str, Format] = {}
@dataclass
class Template:
name: str
def __post_init__(self):
if self.name == "vanilla":
r"""
Supports language model inference without histories.
"""
self._register_template(
prefix="",
prompt="{query}",
sep="",
use_history=False
)
elif self.name == "default":
r"""
Default template.
"""
self._register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
elif self.name == "alpaca":
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
self._register_template(
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)
elif self.name == "vicuna":
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
self._register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)
elif self.name == "belle":
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
self._register_template(
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)
elif self.name == "linly":
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
self._register_template(
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)
elif self.name == "billa":
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
self._register_template(
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
elif self.name == "ziya":
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
self._register_template(
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)
elif self.name == "aquila":
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
self._register_template(
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)
elif self.name == "intern":
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
self._register_template(
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)
elif self.name == "baichuan":
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
self._register_template(
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True
)
if self.name in templates:
self.prefix = templates[self.name].prefix
self.prompt = templates[self.name].prompt
self.sep = templates[self.name].sep
self.use_history = templates[self.name].use_history
else:
raise ValueError("Template {} does not exist.".format(self.name))
@@ -155,14 +43,6 @@ class Template:
"""
return self._format_example(query, history, prefix) + [resp]
def _register_template(
self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
) -> None:
self.prefix = prefix
self.prompt = prompt
self.sep = sep
self.use_history = use_history
def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
@@ -179,3 +59,163 @@ class Template:
convs.append(self.sep + self.prompt.format(query=user_query))
convs.append(bot_resp)
return convs[:-1] # drop last
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Format(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history
)
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
use_history=False
)
r"""
Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt=" <reserved_102> {query} <reserved_103> ",
sep="</s>",
use_history=True
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
)

View File

@@ -11,7 +11,7 @@ from transformers import (
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
@@ -28,7 +28,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
def load_model_and_tokenizer(
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
finetuning_args: FinetuningArguments,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
r"""
Loads pretrained model and tokenizer.
@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
)
# Register auto class to save the custom code files.
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
tokenizer.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model

View File

@@ -25,7 +25,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
@@ -46,12 +45,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
len_dataloader = len(self.dataloader)
num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
max_steps = math.ceil(num_train_epochs * len_dataloader)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
@@ -62,9 +62,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
@@ -77,7 +77,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
@@ -87,59 +87,45 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
batch = next(dataiter)
steps_trained += 1
for _ in range(self.config.gradient_accumulation_steps):
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
batch = next(dataiter)
steps_trained += 1
# Get responses
query_tensors = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
queries, responses = [], []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Get response from model
query_tensors: torch.Tensor = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
queries: List[torch.Tensor] = []
responses: List[torch.Tensor] = []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
if response_length < 2: # make response have at least 2 tokens
responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
else:
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
# Compute rewards
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default")
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = {
"loss": round(loss_meter.avg, 4),
"reward": round(reward_meter.avg, 4),
"learning_rate": stats["ppo/learning_rate"],
"epoch": round(step / num_steps_per_epoch, 2)
}
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
)
print(logs)
logs["step"] = step
self.state.log_history.append(logs)
@@ -150,9 +136,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
if self.control.should_training_stop:
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
@torch.no_grad()
def generate(
self,

View File

@@ -4,7 +4,8 @@
import math
from trl import PPOConfig
from torch.optim import AdamW
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
@@ -19,7 +20,8 @@ def run_ppo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
@@ -30,7 +32,7 @@ def run_ppo(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
@@ -50,7 +52,7 @@ def run_ppo(
ppo_trainer = PPOPeftTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[LogCallback()],
callbacks=callbacks,
config=ppo_config,
model=model,
ref_model=None,

View File

@@ -2,7 +2,8 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from transformers import Seq2SeqTrainingArguments
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
@@ -18,7 +19,8 @@ def run_rm(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@@ -44,7 +46,7 @@ def run_rm(
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[LogCallback()],
callbacks=callbacks,
compute_metrics=compute_accuracy,
**trainer_kwargs
)

View File

@@ -23,7 +23,7 @@ class ComputeMetrics:
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -47,5 +47,6 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}

View File

@@ -32,17 +32,40 @@ class Seq2SeqPeftTrainer(PeftTrainer):
Subclass and override to inject custom behavior.
"""
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
else:
inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len:
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
return (loss, generated_tokens, labels)
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
r"""
Pads the tensor to the same length as the target tensor.
Should only be called when predict_with_generate=True.
"""
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor
def save_predictions(
self,
predict_results: PredictionOutput

View File

@@ -0,0 +1 @@
from llmtuner.webui.interface import create_ui

View File

@@ -0,0 +1,88 @@
import os
from typing import List, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.tuner import get_infer_args
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self, *args):
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if len(args) != 0:
super().__init__(*args)
def load_model(
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
return
if not model_name:
yield ALERTS["err_no_model"][lang]
return
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
yield ALERTS["err_no_path"][lang]
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
else:
checkpoint_dir = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix
)
super().__init__(*get_infer_args(args))
yield ALERTS["info_loaded"][lang]
def unload_model(self, lang: str):
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None
torch_gc()
yield ALERTS["info_unloaded"][lang]
def predict(
self,
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
max_new_tokens: int,
top_p: float,
temperature: float
):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = [query, response]
yield chatbot, new_history

View File

@@ -0,0 +1,75 @@
import json
import os
from typing import Any, Dict, Optional
import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS
DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
def get_save_dir(model_name: str) -> str:
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {"last_model": "", "path_dict": {}}
def save_config(model_name: str, model_path: str) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False)
def get_model_path(model_name: str) -> str:
user_config = load_config()
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([
os.path.isfile(os.path.join(save_dir, checkpoint, name))
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
])
):
checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints)
def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
return gr.update(value=[], choices=list(dataset_info.keys()))

View File

@@ -0,0 +1,4 @@
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab

View File

@@ -0,0 +1,55 @@
from typing import Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel,
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
query = gr.Textbox(show_label=False, lines=8)
with gr.Column(min_width=32, scale=1):
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
clear_btn = gr.Button()
max_new_tokens = gr.Slider(
10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
temperature = gr.Slider(
0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
)
history = gr.State([])
submit_btn.click(
chat_model.predict,
[chatbot, query, history, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
lambda: gr.update(value=""), outputs=[query]
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
)

View File

@@ -0,0 +1,19 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)
with gr.Row():
preview_samples = gr.JSON(interactive=False)
close_btn = gr.Button()
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
return preview_box, preview_count, preview_samples, close_btn

View File

@@ -0,0 +1,73 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
predict = gr.Checkbox(value=True)
with gr.Row():
start_btn = gr.Button()
stop_btn = gr.Button()
output_box = gr.Markdown()
start_btn.click(
runner.run_eval,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
max_samples,
batch_size,
predict
],
[output_box]
)
stop_btn.click(runner.set_abort, queue=False)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
max_samples=max_samples,
batch_size=batch_size,
predict=predict,
start_btn=start_btn,
stop_btn=stop_btn,
output_box=output_box
)

View File

@@ -0,0 +1,49 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Markdown()
chat_model = WebChatModel()
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
load_btn.click(
chat_model.load_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"]
],
[info_box]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)
unload_btn.click(
chat_model.unload_model, [top_elems["lang"]], [info_box]
).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)
return dict(
info_box=info_box,
load_btn=load_btn,
unload_btn=unload_btn,
**chat_elems
)

View File

@@ -0,0 +1,115 @@
from typing import Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
fp16 = gr.Checkbox(value=True)
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
with gr.Row():
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
with gr.Column(scale=4):
output_dir = gr.Textbox(interactive=True)
output_box = gr.Markdown()
with gr.Column(scale=1):
loss_viewer = gr.Plot()
start_btn.click(
runner.run_train,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
learning_rate,
num_train_epochs,
max_samples,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
dev_ratio,
fp16,
logging_steps,
save_steps,
output_dir
],
[output_box]
)
stop_btn.click(runner.set_abort, queue=False)
output_box.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
dev_ratio=dev_ratio,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
output_box=output_box,
loss_viewer=loss_viewer
)

View File

@@ -0,0 +1,55 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.utils import can_quantize
def create_top() -> Dict[str, Component]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Row():
quantization_bit = gr.Dropdown([8, 4], scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
source_prefix = gr.Textbox(scale=4)
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
) # do not save config since the below line will save
model_path.change(save_config, [model_name, model_path])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
return dict(
lang=lang,
model_name=model_name,
model_path=model_path,
finetuning_type=finetuning_type,
template=template,
checkpoints=checkpoints,
refresh_btn=refresh_btn,
quantization_bit=quantization_bit,
source_prefix=source_prefix
)

18
src/llmtuner/webui/css.py Normal file
View File

@@ -0,0 +1,18 @@
CSS = r"""
.modal-box {
position: fixed !important;
top: 50%;
left: 50%;
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: scroll !important;
background-color: var(--input-background-fill);
border: 2px solid black !important;
z-index: 1000;
}
.dark .modal-box {
border: 2px solid white !important;
}
"""

View File

@@ -0,0 +1,54 @@
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
create_top,
create_sft_tab,
create_eval_tab,
create_infer_tab
)
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def create_ui() -> gr.Blocks:
runner = Runner()
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
with gr.Tab("SFT"):
sft_elems = create_sft_tab(top_elems, runner)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
with gr.Tab("Inference"):
infer_elems = create_infer_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
manager = Manager(elem_list)
demo.load(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
)
top_elems["lang"].change(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
)
return demo
if __name__ == "__main__":
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)

View File

@@ -0,0 +1,434 @@
LOCALES = {
"lang": {
"en": {
"label": "Lang"
},
"zh": {
"label": "语言"
}
},
"model_name": {
"en": {
"label": "Model name"
},
"zh": {
"label": "模型名称"
}
},
"model_path": {
"en": {
"label": "Model path",
"info": "Path to pretrained model or model identifier from Hugging Face."
},
"zh": {
"label": "模型路径",
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
}
},
"checkpoints": {
"en": {
"label": "Checkpoints"
},
"zh": {
"label": "模型断点"
}
},
"refresh_btn": {
"en": {
"value": "Refresh checkpoints"
},
"zh": {
"value": "刷新断点"
}
},
"quantization_bit": {
"en": {
"label": "Quantization bit (optional)",
"info": "Enable 4/8-bit model quantization."
},
"zh": {
"label": "量化等级(非必填)",
"info": "启用 4/8 比特模型量化。"
}
},
"template": {
"en": {
"label": "Prompt template",
"info": "The template used in constructing prompts."
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板"
}
},
"source_prefix": {
"en": {
"label": "Source prefix (optional)",
"info": "A sequence used as the prefix of each samples."
},
"zh": {
"label": "前缀序列(非必填)",
"info": "作为每个输入样本前缀的序列"
}
},
"dataset_dir": {
"en": {
"label": "Data dir",
"info": "Path of the data directory."
},
"zh": {
"label": "数据路径",
"info": "数据文件夹的路径。"
}
},
"dataset": {
"en": {
"label": "Dataset"
},
"zh": {
"label": "数据集"
}
},
"preview_btn": {
"en": {
"value": "Preview"
},
"zh": {
"value": "预览"
}
},
"preview_count": {
"en": {
"label": "Count"
},
"zh": {
"label": "数量"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
},
"close_btn": {
"en": {
"value": "Close"
},
"zh": {
"value": "关闭"
}
},
"max_source_length": {
"en": {
"label": "Max source length",
"info": "Max tokens in source sequence."
},
"zh": {
"label": "输入序列最大长度",
"info": "输入序列分词后的最大长度。"
}
},
"max_target_length": {
"en": {
"label": "Max target length",
"info": "Max tokens in target sequence."
},
"zh": {
"label": "输出序列最大长度",
"info": "输出序列分词后的最大长度。"
}
},
"learning_rate": {
"en": {
"label": "Learning rate",
"info": "Initial learning rate for AdamW."
},
"zh": {
"label": "学习率",
"info": "AdamW 优化器的初始学习率。"
}
},
"num_train_epochs": {
"en": {
"label": "Epochs",
"info": "Total number of training epochs to perform."
},
"zh": {
"label": "训练轮数",
"info": "需要执行的训练总轮数。"
}
},
"max_samples": {
"en": {
"label": "Max samples",
"info": "Maximum samples per dataset."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数。"
}
},
"batch_size": {
"en": {
"label": "Batch size",
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
},
"gradient_accumulation_steps": {
"en": {
"label": "Gradient accumulation",
"info": "Number of gradient accumulation steps."
},
"zh": {
"label": "梯度累积",
"info": "梯度累积的步数。"
}
},
"lr_scheduler_type": {
"en": {
"label": "LR Scheduler",
"info": "Name of learning rate scheduler.",
},
"zh": {
"label": "学习率调节器",
"info": "采用的学习率调节器名称。"
}
},
"dev_ratio": {
"en": {
"label": "Dev ratio",
"info": "Proportion of data in the dev set."
},
"zh": {
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
}
},
"fp16": {
"en": {
"label": "fp16",
"info": "Whether to use fp16 mixed precision training."
},
"zh": {
"label": "fp16",
"info": "是否启用 FP16 混合精度训练。"
}
},
"logging_steps": {
"en": {
"label": "Logging steps",
"info": "Number of update steps between two logs."
},
"zh": {
"label": "日志间隔",
"info": "每两次日志输出间的更新步数。"
}
},
"save_steps": {
"en": {
"label": "Save steps",
"info": "Number of updates steps between two checkpoints."
},
"zh": {
"label": "保存间隔",
"info": "每两次断点保存间的更新步数。"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
},
"output_dir": {
"en": {
"label": "Checkpoint name",
"info": "Directory to save checkpoint."
},
"zh": {
"label": "断点名称",
"info": "保存模型断点的文件夹名称。"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"loss_viewer": {
"en": {
"label": "Loss"
},
"zh": {
"label": "损失"
}
},
"predict": {
"en": {
"label": "Save predictions"
},
"zh": {
"label": "保存预测结果"
}
},
"load_btn": {
"en": {
"value": "Load model"
},
"zh": {
"value": "加载模型"
}
},
"unload_btn": {
"en": {
"value": "Unload model"
},
"zh": {
"value": "卸载模型"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"query": {
"en": {
"placeholder": "Input..."
},
"zh": {
"placeholder": "输入..."
}
},
"submit_btn": {
"en": {
"value": "Submit"
},
"zh": {
"value": "提交"
}
},
"clear_btn": {
"en": {
"value": "Clear history"
},
"zh": {
"value": "清空历史"
}
},
"max_length": {
"en": {
"label": "Maximum length"
},
"zh": {
"label": "最大长度"
}
},
"max_new_tokens": {
"en": {
"label": "Maximum new tokens"
},
"zh": {
"label": "最大生成长度"
}
},
"top_p": {
"en": {
"label": "Top-p"
},
"zh": {
"label": "Top-p 采样值"
}
},
"temperature": {
"en": {
"label": "Temperature"
},
"zh": {
"label": "温度系数"
}
}
}
ALERTS = {
"err_conflict": {
"en": "A process is in running, please abort it firstly.",
"zh": "任务已存在,请先中断训练。"
},
"err_exists": {
"en": "You have loaded a model, please unload it first.",
"zh": "模型已存在,请先卸载模型。"
},
"err_no_model": {
"en": "Please select a model.",
"zh": "请选择模型。"
},
"err_no_path": {
"en": "Model not found.",
"zh": "模型未找到。"
},
"err_no_dataset": {
"en": "Please choose a dataset.",
"zh": "请选择数据集。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"
},
"info_aborted": {
"en": "Ready.",
"zh": "准备就绪。"
},
"info_finished": {
"en": "Finished.",
"zh": "训练完毕。"
},
"info_loading": {
"en": "Loading model...",
"zh": "加载中……"
},
"info_unloading": {
"en": "Unloading model...",
"zh": "卸载中……"
},
"info_loaded": {
"en": "Model loaded, now you can chat with your model!",
"zh": "模型已加载,可以开始聊天了!"
},
"info_unloaded": {
"en": "Model unloaded.",
"zh": "模型已卸载。"
}
}

View File

@@ -0,0 +1,35 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time
class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
self.elem_list = elem_list
def gen_refresh(self) -> Dict[str, Any]:
refresh_dict = {
"dataset": {"choices": list_dataset()["choices"]},
"output_dir": {"value": get_time()}
}
user_config = load_config()
if user_config["last_model"]:
refresh_dict["model_name"] = {"value": user_config["last_model"]}
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, dict]:
update_dict = {}
refresh_dict = self.gen_refresh()
for elems in self.elem_list:
for name, component in elems.items():
update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
return update_dict

View File

@@ -0,0 +1,224 @@
import logging
import os
import threading
import time
import transformers
from typing import List, Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
class Runner:
def __init__(self):
self.aborted = False
self.running = False
def set_abort(self):
self.aborted = True
self.running = False
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running:
return None, ALERTS["err_conflict"][lang], None, None
if not model_name:
return None, ALERTS["err_no_model"][lang], None, None
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
return None, ALERTS["err_no_path"][lang], None, None
if len(dataset) == 0:
return None, ALERTS["err_no_dataset"][lang], None, None
self.aborted = False
self.running = True
logger_handler = LoggerHandler()
logger_handler.setLevel(logging.INFO)
logging.root.addHandler(logger_handler)
transformers.logging.add_handler(logger_handler)
trainer_callback = LogCallback(self)
return model_name_or_path, "", logger_handler, trainer_callback
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
self.running = False
torch_gc()
if self.aborted:
return ALERTS["info_aborted"][lang]
else:
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
def run_train(
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
learning_rate: str,
num_train_epochs: str,
max_samples: str,
batch_size: int,
gradient_accumulation_steps: int,
lr_scheduler_type: str,
dev_ratio: float,
fp16: bool,
logging_steps: int,
save_steps: int,
output_dir: str
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
else:
checkpoint_dir = None
args = dict(
model_name_or_path=model_name_or_path,
do_train=True,
overwrite_cache=True,
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
max_samples=int(max_samples),
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
)
if dev_ratio > 1e-6:
args["dev_ratio"] = dev_ratio
args["evaluation_strategy"] = "steps"
args["eval_steps"] = save_steps
args["load_best_model_at_end"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
thread.start()
while thread.is_alive():
time.sleep(1)
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield self.finalize(lang)
def run_eval(
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
max_samples: str,
batch_size: int,
predict: bool
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
args = dict(
model_name_or_path=model_name_or_path,
do_eval=True,
overwrite_cache=True,
predict_with_generate=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
max_samples=int(max_samples),
per_device_eval_batch_size=batch_size,
output_dir=output_dir
)
if predict:
args.pop("do_eval", None)
args["do_predict"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
thread.start()
while thread.is_alive():
time.sleep(1)
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))

View File

@@ -0,0 +1,85 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
def format_info(log: str, tracker: dict) -> str:
info = log
if "current_steps" in tracker:
info += "Running **{:d}/{:d}**: {} < {}\n".format(
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
)
return info
def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
):
return gr.update(interactive=True)
else:
return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
data = json.load(f)
return len(data), data[:2], gr.update(visible=True)
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
if finetuning_type != "lora":
return gr.update(value="", interactive=False)
else:
return gr.update(interactive=True)
def get_eval_results(path: os.PathLike) -> str:
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
return None
plt.close("all")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return None
ax.plot(steps, losses, alpha=0.4, label="original")
ax.plot(steps, smooth(losses), label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig

11
src/train_web.py Normal file
View File

@@ -0,0 +1,11 @@
from llmtuner import create_ui
def main():
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
if __name__ == "__main__":
main()

View File

@@ -3,93 +3,42 @@
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
from llmtuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
model_args, data_args, finetuning_args, generating_args = get_infer_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
def main():
chat_model = WebChatModel(*get_infer_args())
prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
with gr.Blocks(title="Web Demo") as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
lang.change(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
chatbot.append((query, ""))
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = generating_args.to_dict()
gen_kwargs.update({
"input_ids": input_ids,
"top_p": top_p,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"logits_processor": get_logits_processor(),
"streamer": streamer
})
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = (query, response)
yield chatbot, new_history
def reset_user_input():
return gr.update(value="")
def reset_state():
return [], []
with gr.Blocks() as demo:
gr.HTML("""
<h1 align="center">
<a href="https://github.com/hiyouga/LLaMA-Efficient-Tuning" target="_blank">
LLaMA Efficient Tuning
</a>
</h1>
""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
label="Maximum new tokens", interactive=True)
top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
label="Temperature", interactive=True)
history = gr.State([])
submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)
if __name__ == "__main__":
main()

View File

@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
if isinstance(module, BaichuanModel):
module.gradient_checkpointing = value
@staticmethod
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
@staticmethod
def _convert_to_baichuan_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
class BaichuanModel(BaichuanPreTrainedModel):
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
self.embed_tokens = value
def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class BaichuanForCausalLM(BaichuanPreTrainedModel):
@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
@@ -559,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
) -> dict:
if past_key_values:
input_ids = input_ids[:, -1:]
# the cache may be in the standard format (e.g. in contrastive search)
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_baichuan_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
return tuple(
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
for layer_past in past_key_values
)
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
)
return self._convert_to_baichuan_cache(reordered_past)
def quantize(self, bits: int):
try:
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
raise ImportError(
f"Needs QLinear to run quantize."
)
for layer in self.model.layers:
layer.self_attn.W_pack = QLinear(
bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
weight=layer.mlp.up_proj.weight,
bias = None,
)
return self
return self
def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens