Compare commits
41 Commits
| SHA1 |
|---|
| 5fe3cce5a3 |
| 09f165d442 |
| 60aea7521b |
| 29545d0e5e |
| 4a14099cfd |
| b052574ddf |
| 5ea6a7c6d6 |
| 8ca196d51f |
| 5f572cbd77 |
| 679bd3ab30 |
| da3d59fada |
| 835d27151d |
| f1d7228a74 |
| 72bbd5bdef |
| ad9d866547 |
| a1ec668b70 |
| 389687a56d |
| 97280c73b9 |
| f3c622b665 |
| d71e8d8dbf |
| 02c2089ac8 |
| 07ad28a053 |
| d323ccc3ec |
| 4738d002c7 |
| ec099b0586 |
| a51253fea2 |
| 304ec9ec6a |
| 8547085615 |
| 14b139ecb5 |
| 7b45f5068f |
| 99ceee840e |
| 8ed68301e3 |
| 664267e050 |
| 7ef8f46591 |
| 6933c1fed2 |
| 9d125bf533 |
| 08d5340bd8 |
| 0e6f4f981e |
| 670ee3934f |
| 569860d7ac |
| 953a562ec1 |
.gitignore (vendored), 7 changes
@@ -157,4 +157,9 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# custom .gitignore
user.config
saves/
cache/
README.md, 48 changes
@@ -1,4 +1,4 @@
# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort



[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[](LICENSE)

@@ -44,15 +44,23 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/



<details><summary>Definitions</summary>

- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.

</details>

## Changelog

[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.

[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.

<details><summary>Full Changelog</summary>

[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.

[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.

@@ -77,6 +85,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.

</details>

## Supported Models

| Model | Model size | Default module | Template |

@@ -92,7 +102,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |

> [!NOTE]

@@ -156,6 +166,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)

@@ -171,6 +182,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)

</details>

@@ -192,7 +204,15 @@ huggingface-cli login

- gradio and matplotlib (used in web UI)
- uvicorn, fastapi and sse-starlette (used in API)

And **powerful GPUs**!

### Hardware Requirement

| Method | Bits | 7B | 13B | 30B | 65B |
| ------ | ---- | ----- | ----- | ----- | ------ |
| Full | 16 | 140GB | 240GB | 520GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 120GB | 240GB |
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |

## Getting Started

@@ -219,6 +239,28 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi

pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

### Use ModelScope Models (optional)

If you have trouble with downloading models from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.

```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```

Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path modelscope/Llama-2-7b-ms \
    ... # arguments (same as above)
```

LLaMA Board also supports using the models on the ModelScope Hub.

```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```

### Train on a single GPU

> [!IMPORTANT]
README_zh.md, 52 changes
@@ -1,4 +1,4 @@
|
||||
# LLaMA Factory: 轻松的大模型训练与评估
|
||||

|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
@@ -31,7 +31,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [模型](#模型)
|
||||
- [训练方法](#训练方法)
|
||||
- [数据集](#数据集)
|
||||
- [软件依赖](#软件依赖)
|
||||
- [软硬件依赖](#软硬件依赖)
|
||||
- [如何使用](#如何使用)
|
||||
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
|
||||
- [协议](#协议)
|
||||
@@ -44,15 +44,23 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||

|
||||
|
||||
<details><summary>变量定义</summary>
|
||||
|
||||
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024)
|
||||
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
|
||||
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
|
||||
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`。
|
||||
|
||||
</details>
|
||||
|
||||
## 更新日志
|
||||
|
||||
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
||||
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
@@ -77,6 +85,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
|
||||
</details>
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
@@ -92,7 +102,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
> [!NOTE]
|
||||
@@ -156,6 +166,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
@@ -171,6 +182,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -183,7 +195,7 @@ pip install --upgrade huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## 软件依赖
|
||||
## 软硬件依赖
|
||||
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
@@ -192,7 +204,15 @@ huggingface-cli login
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
|
||||
以及 **强而有力的 GPU**!
|
||||
### 硬件依赖
|
||||
|
||||
| 训练方法 | 精度 | 7B | 13B | 30B | 65B |
|
||||
| ------- | ---- | ----- | ----- | ----- | ------ |
|
||||
| 全参数 | 16 | 140GB | 240GB | 520GB | 1200GB |
|
||||
| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |
|
||||
|
||||
## 如何使用
|
||||
|
||||
@@ -219,6 +239,28 @@ pip install -r requirements.txt
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### 使用魔搭社区(可跳过)
|
||||
|
||||
如果您在 Hugging Face 模型的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
|
||||
```bash
|
||||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
```
|
||||
|
||||
接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
||||
... # 参数同上
|
||||
```
|
||||
|
||||
LLaMA Board 同样支持魔搭社区的模型下载。
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
||||
```
|
||||
|
||||
### 单 GPU 训练
|
||||
|
||||
> [!IMPORTANT]
|
||||
|
||||
@@ -7,4 +7,4 @@ from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo


__version__ = "0.3.2"
__version__ = "0.3.3"
@@ -15,7 +15,9 @@ from llmtuner.api.protocol import (
    ChatCompletionStreamResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionResponseUsage
    ChatCompletionResponseUsage,
    ScoreEvaluationRequest,
    ScoreEvaluationResponse
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc

@@ -68,6 +70,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":

    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
    async def create_chat_completion(request: ChatCompletionRequest):
        if not chat_model.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

@@ -156,6 +161,17 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":

            yield to_json(chunk)
        yield "[DONE]"

    @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
    async def create_score_evaluation(request: ScoreEvaluationRequest):
        if chat_model.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if len(request.messages) == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

        scores = chat_model.get_scores(request.messages, max_length=request.max_length)
        return ScoreEvaluationResponse(model=request.model, scores=scores)

    return app
@@ -81,3 +81,16 @@ class ChatCompletionStreamResponse(BaseModel):
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


class ScoreEvaluationRequest(BaseModel):
    model: str
    messages: List[str]
    max_length: Optional[int] = None


class ScoreEvaluationResponse(BaseModel):
    id: Optional[str] = "scoreeval-default"
    object: Optional[str] = "score.evaluation"
    model: str
    scores: List[float]
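As a usage sketch for the new endpoint (not part of this diff): the request body mirrors `ScoreEvaluationRequest` above, and the server address and model name below are assumptions.

```python
# Hypothetical client sketch for the new /v1/score/evaluation endpoint.
# Host, port, and model name are assumptions; the payload fields
# (model, messages, max_length) mirror ScoreEvaluationRequest above.
import requests

resp = requests.post(
    "http://localhost:8000/v1/score/evaluation",  # assumed server address
    json={
        "model": "default",                        # assumed model name
        "messages": ["How are you?", "I am fine."],
        "max_length": 1024,
    },
)
resp.raise_for_status()
print(resp.json()["scores"])  # ScoreEvaluationResponse.scores: List[float]
```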
@@ -1,4 +1,5 @@
import torch
import tiktoken
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread

@@ -22,8 +23,11 @@ class ChatModel:

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
        self.tokenizer.padding_side = "left"
        self.can_generate = (finetuning_args.stage == "sft")
        self.model, self.tokenizer = load_model_and_tokenizer(
            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.model = dispatch_model(self.model)
        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
        self.system_prompt = data_args.system_prompt

@@ -130,3 +134,41 @@ class ChatModel:

        thread.start()

        yield from streamer

    @torch.inference_mode()
    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs
    ) -> List[float]:
        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=True)

        max_length = input_kwargs.pop("max_length", None)
        device = getattr(self.model.pretrained_model, "device", "cuda")

        inputs = self.tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
            pad_to_multiple_of=8,
            return_tensors="pt",
            **kwargs
        ).to(device)

        input_ids: torch.Tensor = inputs["input_ids"]
        _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)

        if getattr(self.model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)

        scores = []
        for i in range(input_ids.size(0)):
            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            scores.append(values[i, end_index].nan_to_num().item())

        return scores
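A minimal sketch of how the new scoring path might be driven from Python, assuming a reward-model checkpoint is available: with `stage="rm"` the value head is attached (`add_valuehead=True`) and `can_generate` is False, so only `get_scores()` applies. The path and template below are placeholders, not values from this diff.

```python
# Minimal sketch of the non-generating (scoring) path added above.
from llmtuner.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="path/to/reward_model",  # placeholder checkpoint
    template="default",                         # placeholder template name
    stage="rm",                                 # reward model, so can_generate is False
))
scores = chat_model.get_scores(
    ["How are you?", "I am fine."],
    max_length=1024,
)
print(scores)  # one float per input string
```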
@@ -408,18 +408,31 @@ register_template(
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
"User: {{query}}\n\nAssistant:"
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer."
|
||||
"and other non-computer science questions, you will refuse to answer\n"
|
||||
),
|
||||
sep=[
|
||||
"\n",
|
||||
{"token": "<|EOT|>"},
|
||||
"\n\n"
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|EOT|>"
|
||||
@@ -637,6 +650,23 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xuanyuan",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}} Assistant:"
|
||||
],
|
||||
system=(
|
||||
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
||||
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
||||
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xverse",
|
||||
prefix=[
|
||||
@@ -682,6 +712,22 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yi",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"<|im_end|>\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="zephyr",
|
||||
prefix=[
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.modeling_utils import custom_object_save, unwrap_model
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
@@ -18,6 +19,16 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
||||
model.pretrained_model.config.save_pretrained(output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(output_dir)
|
||||
elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
|
||||
custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@@ -25,25 +36,17 @@ class SavePeftModelCallback(TrainerCallback):
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
|
||||
model.pretrained_model.config.save_pretrained(output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(output_dir)
|
||||
_save_model_with_valuehead(
|
||||
model=unwrap_model(kwargs.pop("model")),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
|
||||
model.pretrained_model.config.save_pretrained(args.output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(args.output_dir)
|
||||
_save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
@@ -1,3 +1,4 @@
from enum import Enum
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
@@ -28,9 +29,13 @@ TRAINING_STAGES = {
    "Pre-Training": "pt"
}


class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"


def register_model_group(
    models: Dict[str, str],
    models: Dict[str, Dict[DownloadSource, str]],
    module: Optional[str] = None,
    template: Optional[str] = None
) -> None:
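With the registry now keyed by `DownloadSource`, a caller can pick the ModelScope ID when `USE_MODELSCOPE_HUB=1` is set. The helper below is an illustrative sketch under that assumption, not the project's actual lookup code; the enum and entry layout come from the hunk above, and the example IDs appear later in this diff.

```python
# Illustrative sketch only: resolving a DownloadSource-keyed registry entry.
import os
from enum import Enum
from typing import Dict


class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"


def resolve_model_path(entry: Dict[DownloadSource, str]) -> str:
    # Prefer the ModelScope ID when the env var gate is on and an ID exists.
    use_modelscope = bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
    if use_modelscope and DownloadSource.MODELSCOPE in entry:
        return entry[DownloadSource.MODELSCOPE]
    return entry[DownloadSource.DEFAULT]


print(resolve_model_path({
    DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
    DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
}))
```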
@@ -49,9 +54,18 @@ def register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
|
||||
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
|
||||
"Baichuan-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
|
||||
},
|
||||
"Baichuan-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
|
||||
},
|
||||
"Baichuan-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
|
||||
}
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan"
|
||||
@@ -60,10 +74,22 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
||||
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
|
||||
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
|
||||
"Baichuan2-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
|
||||
},
|
||||
"Baichuan2-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
|
||||
},
|
||||
"Baichuan2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
|
||||
},
|
||||
"Baichuan2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
|
||||
}
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan2"
|
||||
@@ -72,9 +98,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1"
|
||||
"BLOOM-560M": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
|
||||
},
|
||||
"BLOOM-3B": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
|
||||
},
|
||||
"BLOOM-7B1": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
|
||||
}
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
@@ -82,9 +117,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
|
||||
"BLOOMZ-560M": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
|
||||
},
|
||||
"BLOOMZ-3B": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
|
||||
},
|
||||
"BLOOMZ-7B1-mt": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
|
||||
}
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
@@ -92,8 +136,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
||||
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
|
||||
"BlueLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
|
||||
},
|
||||
"BlueLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
|
||||
}
|
||||
},
|
||||
template="bluelm"
|
||||
)
|
||||
@@ -101,7 +151,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||
"ChatGLM2-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm2"
|
||||
@@ -110,8 +163,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
||||
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
|
||||
"ChatGLM3-6B-Base": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
|
||||
},
|
||||
"ChatGLM3-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm3"
|
||||
@@ -120,12 +179,30 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
|
||||
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
|
||||
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
|
||||
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
|
||||
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
|
||||
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
|
||||
"ChineseLLaMA2-1.3B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
|
||||
},
|
||||
"ChineseLLaMA2-7B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
|
||||
},
|
||||
"ChineseLLaMA2-13B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
|
||||
},
|
||||
"ChineseLLaMA2-1.3B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
|
||||
},
|
||||
"ChineseLLaMA2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
|
||||
},
|
||||
"ChineseLLaMA2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
|
||||
}
|
||||
},
|
||||
template="llama2_zh"
|
||||
)
|
||||
@@ -133,12 +210,76 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": "tiiuae/falcon-7b",
|
||||
"Falcon-40B": "tiiuae/falcon-40b",
|
||||
"Falcon-180B": "tiiuae/falcon-180B",
|
||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
|
||||
"DeepseekLLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
|
||||
},
|
||||
"DeepseekLLM-67B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
|
||||
},
|
||||
"DeepseekLLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
|
||||
},
|
||||
"DeepseekLLM-67B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
|
||||
}
|
||||
},
|
||||
template="deepseek"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepseekCoder-6.7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
|
||||
},
|
||||
"DeepseekCoder-33B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
|
||||
},
|
||||
"DeepseekCoder-6.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
|
||||
},
|
||||
"DeepseekCoder-33B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
|
||||
}
|
||||
},
|
||||
template="deepseekcoder"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
|
||||
},
|
||||
"Falcon-40B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
|
||||
},
|
||||
"Falcon-180B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
|
||||
},
|
||||
"Falcon-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
|
||||
},
|
||||
"Falcon-40B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
|
||||
},
|
||||
"Falcon-180B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="falcon"
|
||||
@@ -147,10 +288,22 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-20B": "internlm/internlm-20b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
|
||||
"InternLM-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
|
||||
},
|
||||
"InternLM-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
|
||||
},
|
||||
"InternLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
|
||||
},
|
||||
"InternLM-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
|
||||
}
|
||||
},
|
||||
template="intern"
|
||||
)
|
||||
@@ -158,7 +311,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
|
||||
"LingoWhale-8B": {
|
||||
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
|
||||
}
|
||||
},
|
||||
module="qkv_proj"
|
||||
)
|
||||
@@ -166,22 +322,52 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": "huggyllama/llama-7b",
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
"LLaMA-30B": "huggyllama/llama-30b",
|
||||
"LLaMA-65B": "huggyllama/llama-65b"
|
||||
"LLaMA-7B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
|
||||
},
|
||||
"LLaMA-13B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
|
||||
},
|
||||
"LLaMA-30B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
|
||||
},
|
||||
"LLaMA-65B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
|
||||
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
|
||||
"LLaMA2-7B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
|
||||
},
|
||||
"LLaMA2-13B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
|
||||
},
|
||||
"LLaMA2-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
|
||||
},
|
||||
"LLaMA2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
|
||||
},
|
||||
"LLaMA2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
|
||||
},
|
||||
"LLaMA2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
|
||||
}
|
||||
},
|
||||
template="llama2"
|
||||
)
|
||||
@@ -189,8 +375,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
|
||||
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
"Mistral-7B": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
|
||||
},
|
||||
"Mistral-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
|
||||
}
|
||||
},
|
||||
template="mistral"
|
||||
)
|
||||
@@ -198,7 +390,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
|
||||
"OpenChat3.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
|
||||
}
|
||||
},
|
||||
template="openchat"
|
||||
)
|
||||
@@ -206,7 +401,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||
"Phi1.5-1.3B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
||||
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
|
||||
}
|
||||
},
|
||||
module="Wqkv"
|
||||
)
|
||||
@@ -214,10 +412,70 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
|
||||
"Qwen-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
|
||||
},
|
||||
"Qwen-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
|
||||
},
|
||||
"Qwen-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
|
||||
},
|
||||
"Qwen-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
|
||||
},
|
||||
"Qwen-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
|
||||
},
|
||||
"Qwen-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
|
||||
},
|
||||
"Qwen-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
|
||||
},
|
||||
"Qwen-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
|
||||
},
|
||||
"Qwen-1.8B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
|
||||
},
|
||||
"Qwen-1.8B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
|
||||
},
|
||||
"Qwen-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
|
||||
},
|
||||
"Qwen-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
|
||||
},
|
||||
"Qwen-14B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
|
||||
},
|
||||
"Qwen-14B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
|
||||
},
|
||||
"Qwen-72B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
|
||||
},
|
||||
"Qwen-72B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
|
||||
}
|
||||
},
|
||||
module="c_attn",
|
||||
template="qwen"
|
||||
@@ -226,15 +484,24 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
|
||||
"Skywork-13B-Base": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
|
||||
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
|
||||
"Vicuna1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
|
||||
},
|
||||
"Vicuna1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
|
||||
}
|
||||
},
|
||||
template="vicuna"
|
||||
)
|
||||
@@ -242,11 +509,45 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": "xverse/XVERSE-7B",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"XVERSE-65B": "xverse/XVERSE-65B",
|
||||
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
|
||||
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
|
||||
"XuanYuan-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
|
||||
},
|
||||
"XuanYuan-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
|
||||
},
|
||||
"XuanYuan-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
|
||||
},
|
||||
"XuanYuan-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
|
||||
}
|
||||
},
|
||||
template="xuanyuan"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
|
||||
},
|
||||
"XVERSE-13B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
|
||||
},
|
||||
"XVERSE-65B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
|
||||
},
|
||||
"XVERSE-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
|
||||
},
|
||||
"XVERSE-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
|
||||
}
|
||||
},
|
||||
template="xverse"
|
||||
)
|
||||
@@ -254,8 +555,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yayi-7B": "wenge-research/yayi-7b-llama2",
|
||||
"Yayi-13B": "wenge-research/yayi-13b-llama2"
|
||||
"Yayi-7B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
|
||||
},
|
||||
"Yayi-13B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
|
||||
}
|
||||
},
|
||||
template="yayi"
|
||||
)
|
||||
@@ -263,16 +570,37 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": "01-ai/Yi-6B",
|
||||
"Yi-34B": "01-ai/Yi-34B"
|
||||
}
|
||||
"Yi-6B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B"
|
||||
},
|
||||
"Yi-34B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
|
||||
},
|
||||
"Yi-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
|
||||
},
|
||||
"Yi-34B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
|
||||
}
|
||||
},
|
||||
template="yi"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
|
||||
"Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
|
||||
"Zephyr-7B-Alpha-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
|
||||
},
|
||||
"Zephyr-7B-Beta-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
||||
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
|
||||
}
|
||||
},
|
||||
template="zephyr"
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ except ImportError:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
@@ -67,15 +68,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_current_device() -> str:
|
||||
import accelerate
|
||||
dummy_accelerator = accelerate.Accelerator()
|
||||
if accelerate.utils.is_xpu_available():
|
||||
return "xpu:{}".format(dummy_accelerator.local_process_index)
|
||||
else:
|
||||
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
@@ -116,3 +108,23 @@ def torch_gc() -> None:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def try_download_model_from_ms(model_args: "ModelArguments") -> None:
    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
        return

    try:
        from modelscope import snapshot_download  # type: ignore
        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        model_args.model_name_or_path = snapshot_download(
            model_args.model_name_or_path,
            revision=revision,
            cache_dir=model_args.cache_dir
        )
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")


def use_modelscope() -> bool:
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
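A hedged sketch of how the helper above could be exercised on its own: it rewrites `model_name_or_path` in place to a local snapshot directory. The dataclass is a stand-in for `ModelArguments` (only the attributes the helper reads), and the model ID is an example.

```python
# Sketch: try_download_model_from_ms mutates the model path in place.
import os
from dataclasses import dataclass
from typing import Optional

from llmtuner.extras.misc import try_download_model_from_ms


@dataclass
class DummyModelArguments:  # stand-in for llmtuner.hparams.ModelArguments
    model_name_or_path: str
    model_revision: str = "main"
    cache_dir: Optional[str] = None


os.environ["USE_MODELSCOPE_HUB"] = "1"  # same effect as `export USE_MODELSCOPE_HUB=1`
args = DummyModelArguments(model_name_or_path="modelscope/Llama-2-7b-ms")
try_download_model_from_ms(args)        # downloads via modelscope.snapshot_download
print(args.model_name_or_path)          # now a local snapshot directory
```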
@@ -18,6 +18,7 @@ _flash_attn2_available = is_package_available("flash_attn") and get_package_vers
|
||||
_jieba_available = is_package_available("jieba")
|
||||
_matplotlib_available = is_package_available("matplotlib")
|
||||
_nltk_available = is_package_available("nltk")
|
||||
_requests_available = is_package_available("requests")
|
||||
_rouge_available = is_package_available("rouge_chinese")
|
||||
_starlette_available = is_package_available("sse_starlette")
|
||||
_uvicorn_available = is_package_available("uvicorn")
|
||||
@@ -43,6 +44,10 @@ def is_nltk_available():
|
||||
return _nltk_available
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return _requests_available
|
||||
|
||||
|
||||
def is_rouge_available():
|
||||
return _rouge_available
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from typing import List, Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
|
||||
@@ -130,11 +133,11 @@ class DataArguments:
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
except Exception as err:
|
||||
if self.dataset is not None:
|
||||
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
|
||||
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
|
||||
dataset_info = None
|
||||
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
@@ -147,7 +150,7 @@ class DataArguments:
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for i, name in enumerate(dataset_names):
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||
|
||||
if "hf_hub_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
|
||||
@@ -118,9 +118,9 @@ class RLHFArguments:
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."}
|
||||
)
|
||||
reward_model_type: Optional[Literal["lora", "full"]] = field(
|
||||
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
|
||||
)
|
||||
|
||||
|
||||
@@ -149,6 +149,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
export_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
@@ -175,7 +179,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("Lora reward model only supports lora training.")
|
||||
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
|
||||
@@ -8,7 +8,8 @@ class ModelArguments:
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
||||
metadata={"help": "Path to pretrained model or model identifier from \
|
||||
huggingface.co/models or modelscope.cn/models."}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
|
||||
@@ -87,7 +87,7 @@ def init_adapter(
|
||||
|
||||
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
||||
target_modules = find_all_linear_modules(model)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
@@ -21,8 +22,8 @@ try:
|
||||
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
@@ -48,7 +49,7 @@ def load_model_and_tokenizer(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
add_valuehead: Optional[bool] = False
|
||||
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
@@ -56,6 +57,8 @@ def load_model_and_tokenizer(
|
||||
Support both training and inference.
|
||||
"""
|
||||
|
||||
try_download_model_from_ms(model_args)
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
@@ -144,6 +147,14 @@ def load_model_and_tokenizer(
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using gptq or awq)
|
||||
if getattr(config, "quantization_config", None):
|
||||
if model_args.quantization_bit is not None: # remove bnb quantization
|
||||
model_args.quantization_bit = None
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
quantization_config = getattr(config, "quantization_config", None)
|
||||
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
|
||||
|
||||
# Quantization configurations (using bitsandbytes library)
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
@@ -151,12 +162,10 @@ def load_model_and_tokenizer(
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

        if model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,

@@ -164,7 +173,7 @@

                bnb_4bit_quant_type=model_args.quantization_type
            )

        config_kwargs["device_map"] = {"": get_current_device()}
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    # Load pre-trained models (without valuehead)
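For context, the hunk above switches from bare `load_in_8bit`/`load_in_4bit` kwargs to an explicit `BitsAndBytesConfig` passed through `config_kwargs`. A standalone sketch of that pattern (example model ID and dtypes, not values from this diff) might look like:

```python
# Standalone sketch: explicit BitsAndBytesConfig instead of bare load_in_* kwargs.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

config_kwargs = {
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # example compute dtype
        bnb_4bit_quant_type="nf4",              # example quant type
    ),
    "device_map": {"": 0},  # single local device, as in the hunk above
}

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # example model ID
    **config_kwargs,
)
```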
@@ -196,10 +205,9 @@ def load_model_and_tokenizer(
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage in ["rm", "ppo"]:
|
||||
if add_valuehead:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
|
||||
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
||||
@@ -215,6 +223,9 @@ def load_model_and_tokenizer(
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
|
||||
@@ -22,10 +22,10 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
if getattr(model, "quantization_method", None): # already set on current device
return model

if torch.cuda.device_count() > 1:
if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory

@@ -42,18 +42,18 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
return model.cuda()

def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
else:
quantization_method = getattr(model, "quantization_method", None)
if quantization_method is None:
linear_cls = torch.nn.Linear
elif quantization_method == "bitsandbytes":
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))

output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
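`find_all_linear_modules` now infers the linear layer class from the model's `quantization_method` instead of a separate `quantization_bit` argument; its output is the list of module names handed to PEFT as LoRA targets. A minimal sketch of that hand-off, assuming an already loaded `model` and the PEFT library (the hyperparameters are illustrative, not the project's defaults):

```python
from peft import LoraConfig, TaskType, get_peft_model

# e.g. ["q_proj", "k_proj", "v_proj", "o_proj", ...] for a LLaMA-style model
target_modules = find_all_linear_modules(model)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,                      # illustrative rank
    lora_alpha=64,
    lora_dropout=0.1,
    target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
```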
@@ -25,11 +25,11 @@ def run_dpo(
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=4,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)

@@ -37,7 +37,7 @@ def run_dpo(
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
ref_model = create_ref_model(model_args, finetuning_args)

# Update arguments
training_args_dict = training_args.to_dict()
@@ -3,10 +3,12 @@ import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_pt_utils import remove_dummy_checkpoint

from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
@@ -14,7 +16,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model

if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -55,17 +57,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

self.state = TrainerState()
self.control = TrainerControl()
self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)

if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")

if reward_model is not None:
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if is_deepspeed_enabled:
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
@@ -198,7 +200,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)

@torch.no_grad()
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
@@ -206,7 +208,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
layernorm_params = dump_layernorm(self.model)

unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
logits_processor=get_logits_processor(),
**batch
@@ -215,7 +217,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.finetuning_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)

query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
query = batch["input_ids"].detach().cpu()
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
@@ -240,17 +243,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.

Both inputs and outputs are put on CPU.
"""
if self.reward_model is None:
if self.finetuning_args.reward_model_type == "api":
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return get_rewards_from_server(self.reward_model, messages)

if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
reward_model = self.model
else:
reward_model = self.reward_model

batch = self.prepare_model_inputs(queries, responses)

with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
reward_model = self.reward_model if self.reward_model is not None else self.model
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)

if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1)

rewards = []
@@ -259,7 +271,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

if self.reward_model is None:
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")

return rewards
@@ -298,7 +310,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)

if values.size(0) != input_ids.size(0): # adapt to chatglm2
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)

logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
@@ -344,4 +357,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
Subclass and override to inject custom behavior.
"""
if self.args.should_save:
self._save(output_dir)
try:
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
except ValueError:
logger.warning(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
" zero_to_fp32.py to recover weights"
)
self._save(output_dir, state_dict={})
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model.save_checkpoint(output_dir) # wrapped model
@@ -1,10 +1,24 @@
import json
import torch
from typing import TYPE_CHECKING, Dict, Literal, Optional
from typing import TYPE_CHECKING, Dict, List, Literal, Optional

from llmtuner.extras.packages import is_requests_available

if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead

if is_requests_available():
import requests

def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers)
rewards = json.loads(response.text)["scores"]
return torch.Tensor(rewards)

def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
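The new `get_rewards_from_server` helper is what the `reward_model_type == "api"` branch of the PPO trainer calls with the decoded query/response strings. A minimal sketch of the round trip, assuming a reward server is already running at a placeholder URL and answers with a JSON body containing a `scores` list:

```python
# Hypothetical endpoint; in practice the URL is supplied as the reward model
# when the API reward type is selected. The server must reply like {"scores": [0.42]}.
scores = get_rewards_from_server(
    server_url="http://localhost:8000/v1/score",
    messages=["Human: Is Python dynamically typed?\nAssistant: Yes, it is."],
)
print(scores)  # a 1-D tensor with one reward per decoded query+response string
```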
@@ -28,14 +28,14 @@ def run_ppo(
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")

tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
reward_model = create_reward_model(model, model_args, finetuning_args)

# Create ppo config

@@ -22,7 +22,7 @@ def run_pt(
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

@@ -39,7 +39,9 @@ class PairwiseTrainer(Trainer):
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2

unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)

# Split the inputs and rewards into two parts, chosen and rejected

@@ -25,9 +25,9 @@ def run_rm(
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

# Update arguments
training_args_dict = training_args.to_dict()

@@ -26,7 +26,7 @@ def run_sft(
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")

if training_args.predict_with_generate:
@@ -34,7 +34,7 @@ def run_sft(

data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
@@ -34,15 +34,15 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
raise ValueError("Unknown task.")

def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
def export_model(args: Optional[Dict[str, Any]] = None):
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

if getattr(model, "quantization_method", None) == "gptq":
raise ValueError("Cannot export a GPTQ quantized model.")
if getattr(model, "quantization_method", None) in ["gptq", "awq"]:
raise ValueError("Cannot export a GPTQ or AWQ quantized model.")

model.config.use_cache = True
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))

try:
tokenizer.padding_side = "left" # restore padding side
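With this change the shard size moves from a `max_shard_size` function argument into `finetuning_args.export_size`. A rough sketch of calling the exporter directly, matching the argument dict the web UI's `save_model` builds below (the import path, model id, and directory names are placeholders, not verified values):

```python
from llmtuner import export_model  # assumed import path for the tuner entry point

# Illustrative arguments only; paths and model id are placeholders.
export_model(dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    template="default",
    finetuning_type="lora",
    checkpoint_dir="saves/LLaMA2-7B/lora/sft",
    export_dir="exported/llama2-7b-sft",
    export_size=1,  # shard size in GB, replacing the old max_shard_size argument
))
```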
@@ -1,5 +1,5 @@
import torch
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Optional, Union

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
@@ -35,7 +35,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
add_valuehead: Optional[bool] = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@@ -51,13 +51,17 @@ def create_ref_model(
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
ref_model, _ = load_model_and_tokenizer(
ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
ref_model, _ = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from the model itself.")

return ref_model
@@ -71,7 +75,11 @@ def create_reward_model(
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info("Use reward server {}".format(finetuning_args.reward_model))
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
@@ -93,7 +101,9 @@ def create_reward_model(
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
reward_model, _ = load_model_and_tokenizer(
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model
@@ -90,6 +90,7 @@ class WebChatModel(ChatModel):
lang = data[self.manager.get_elem_by_name("top.lang")]

if self.demo_mode:
gr.Warning(ALERTS["err_demo"][lang])
yield ALERTS["err_demo"][lang]
return

@@ -11,14 +11,21 @@ from transformers.utils import (
ADAPTER_SAFE_WEIGHTS_NAME
)

from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
from llmtuner.extras.constants import (
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
SUPPORTED_MODELS,
TRAINING_STAGES,
DownloadSource
)
from llmtuner.extras.misc import use_modelscope
from llmtuner.hparams.data_args import DATA_CONFIG

DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
CKPT_NAMES = [
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
@@ -58,7 +65,15 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona

def get_model_path(model_name: str) -> str:
user_config = load_config()
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, [])
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "")
if (
use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path
model_path = path_dict.get(DownloadSource.MODELSCOPE)
return model_path

def get_prefix(model_name: str) -> str:
@@ -89,12 +104,12 @@ def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
return gr.update(value=[], choices=checkpoints)

def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f)
except:
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
except Exception as err:
print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
return {}
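`get_model_path` now prefers a ModelScope model id whenever the ModelScope hub is enabled and the user has not overridden the path in `user.config`. A small sketch of how that switch is typically driven (assuming `use_modelscope()` reads an environment flag such as `USE_MODELSCOPE_HUB`; the flag name and model name here are assumptions):

```python
import os

# Hypothetical illustration: enable the ModelScope hub before the web UI resolves model paths.
os.environ["USE_MODELSCOPE_HUB"] = "1"  # assumed flag consulted by use_modelscope()

# With the flag set and no user override, get_model_path would return the
# DownloadSource.MODELSCOPE entry instead of the default (Hugging Face) id.
print(get_model_path("LLaMA2-7B"))      # model name is a placeholder
```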
@@ -38,10 +38,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
output_dir = gr.Textbox()

input_elems.update({max_new_tokens, top_p, temperature})
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
))

with gr.Row():

@@ -40,18 +40,19 @@ def save_model(
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
finetuning_type=finetuning_type,
template=template,
export_dir=export_dir
export_dir=export_dir,
export_size=max_shard_size
)

yield ALERTS["info_exporting"][lang]
export_model(args, max_shard_size="{}GB".format(max_shard_size))
export_model(args)
yield ALERTS["info_exported"][lang]

def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)

export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)

@@ -49,7 +49,10 @@ class Engine:
else:
yield self._form_dict({"eval.resume_btn": {"value": True}})
else:
yield self._form_dict({"train.output_dir": {"value": get_time()}})
yield self._form_dict({
"train.output_dir": {"value": "train_" + get_time()},
"eval.output_dir": {"value": "eval_" + get_time()},
})

def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return {

@@ -132,7 +132,7 @@ LOCALES = {
"dataset_dir": {
"en": {
"label": "Data dir",
"info": "Path of the data directory."
"info": "Path to the data directory."
},
"zh": {
"label": "数据路径",
@@ -475,12 +475,12 @@ LOCALES = {
},
"output_dir": {
"en": {
"label": "Checkpoint name",
"info": "Directory to save checkpoint."
"label": "Output dir",
"info": "Directory for saving results."
},
"zh": {
"label": "断点名称",
"info": "保存模型断点的文件夹名称。"
"label": "输出目录",
"info": "保存结果的路径。"
}
},
"output_box": {

@@ -87,9 +87,9 @@ class Runner:
user_config = load_config()

if get("top.checkpoints"):
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
else:
checkpoint_dir = None

@@ -160,15 +160,11 @@ class Runner:
user_config = load_config()

if get("top.checkpoints"):
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
else:
checkpoint_dir = None
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")

args = dict(
stage="sft",
@@ -192,7 +188,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=output_dir
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
)

if get("eval.predict"):
@@ -242,6 +238,7 @@ class Runner:
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval")
))

while self.thread.is_alive():
time.sleep(2)
if self.aborted:
@@ -44,7 +44,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str:
args.pop("disable_tqdm", None)
args["plot_loss"] = args.get("do_train", None)
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
for k, v in args.items():
if v is not None and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v)))
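`gen_cmd` now echoes whatever `CUDA_VISIBLE_DEVICES` the web UI process itself was started with instead of hard-coding device 0. A small sketch of the kind of command string it assembles (the argument values are placeholders, not output captured from the UI):

```python
# Illustrative input: a trimmed-down args dict like the one the Runner builds.
args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="alpaca_gpt4_en",
    finetuning_type="lora",
    output_dir="saves/demo",
)

# With CUDA_VISIBLE_DEVICES=0,1 in the environment, gen_cmd(args) would yield roughly:
# CUDA_VISIBLE_DEVICES=0,1 python src/train_bash.py --stage sft --do_train True ...
print(gen_cmd(args))
```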