Compare commits
99 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
790a31404a | ||
|
|
f927601702 | ||
|
|
c4654d54d7 | ||
|
|
df777c30d1 | ||
|
|
d81ad2d4bc | ||
|
|
9f77e8b025 | ||
|
|
04dc3f4614 | ||
|
|
7d1fe50977 | ||
|
|
c0e5e3c5d5 | ||
|
|
3a45cfb604 | ||
|
|
393e4b0f5a | ||
|
|
296711d502 | ||
|
|
9121722999 | ||
|
|
d8d74091f6 | ||
|
|
33521fb45e | ||
|
|
e5204e60ed | ||
|
|
0409428d87 | ||
|
|
f902b0d420 | ||
|
|
27ef5b1aa7 | ||
|
|
c32303fc7e | ||
|
|
45abe361ba | ||
|
|
3ae479faae | ||
|
|
5698038f49 | ||
|
|
020233f725 | ||
|
|
6f9d55b8eb | ||
|
|
2542b62d77 | ||
|
|
95678bb6b1 | ||
|
|
a78759e7ee | ||
|
|
cc5c523f58 | ||
|
|
e39bbdd287 | ||
|
|
d9a50bf93f | ||
|
|
934d00ea1e | ||
|
|
c27675f70d | ||
|
|
7c9f37c83d | ||
|
|
b9736c13e0 | ||
|
|
c47725ff34 | ||
|
|
3ee3fe0bbb | ||
|
|
e54dad75da | ||
|
|
39c2f03eab | ||
|
|
fb9e1c4087 | ||
|
|
ed26bb3d82 | ||
|
|
0baf32e219 | ||
|
|
79a376d1db | ||
|
|
b634e91c43 | ||
|
|
9e2cc21d04 | ||
|
|
6975124a57 | ||
|
|
9f69307db1 | ||
|
|
c3448a045c | ||
|
|
95c561983c | ||
|
|
7a03c8dab5 | ||
|
|
f3ffa8310f | ||
|
|
596f496f19 | ||
|
|
2e6ed731cf | ||
|
|
24ce319b6f | ||
|
|
7b7bfea37d | ||
|
|
3be461260a | ||
|
|
8dab8d9831 | ||
|
|
fb4c5f3c91 | ||
|
|
5fe3cce5a3 | ||
|
|
09f165d442 | ||
|
|
60aea7521b | ||
|
|
29545d0e5e | ||
|
|
4a14099cfd | ||
|
|
b052574ddf | ||
|
|
5ea6a7c6d6 | ||
|
|
8ca196d51f | ||
|
|
5f572cbd77 | ||
|
|
679bd3ab30 | ||
|
|
da3d59fada | ||
|
|
835d27151d | ||
|
|
f1d7228a74 | ||
|
|
72bbd5bdef | ||
|
|
ad9d866547 | ||
|
|
a1ec668b70 | ||
|
|
389687a56d | ||
|
|
97280c73b9 | ||
|
|
f3c622b665 | ||
|
|
d71e8d8dbf | ||
|
|
02c2089ac8 | ||
|
|
07ad28a053 | ||
|
|
d323ccc3ec | ||
|
|
4738d002c7 | ||
|
|
ec099b0586 | ||
|
|
a51253fea2 | ||
|
|
304ec9ec6a | ||
|
|
8547085615 | ||
|
|
14b139ecb5 | ||
|
|
7b45f5068f | ||
|
|
99ceee840e | ||
|
|
8ed68301e3 | ||
|
|
664267e050 | ||
|
|
7ef8f46591 | ||
|
|
6933c1fed2 | ||
|
|
9d125bf533 | ||
|
|
08d5340bd8 | ||
|
|
0e6f4f981e | ||
|
|
670ee3934f | ||
|
|
569860d7ac | ||
|
|
953a562ec1 |
7
.gitignore
vendored
7
.gitignore
vendored
@@ -157,4 +157,9 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# custom .gitignore
|
||||
user.config
|
||||
saves/
|
||||
cache/
|
||||
|
||||
107
README.md
107
README.md
@@ -1,4 +1,4 @@
|
||||
# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
|
||||

|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
@@ -6,7 +6,7 @@
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/c2EPEt5NU)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
|
||||
@@ -44,20 +44,30 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||

|
||||
|
||||
<details><summary>Definitions</summary>
|
||||
|
||||
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
|
||||
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
|
||||
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
|
||||
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
|
||||
|
||||
</details>
|
||||
|
||||
## Changelog
|
||||
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
|
||||
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
||||
|
||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
|
||||
|
||||
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
|
||||
|
||||
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
|
||||
|
||||
[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
|
||||
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
|
||||
|
||||
@@ -77,6 +87,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
|
||||
</details>
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Default module | Template |
|
||||
@@ -91,8 +103,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
> [!NOTE]
|
||||
@@ -156,6 +169,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
@@ -171,6 +185,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -192,7 +207,15 @@ huggingface-cli login
|
||||
- gradio and matplotlib (used in web UI)
|
||||
- uvicorn, fastapi and sse-starlette (used in API)
|
||||
|
||||
And **powerful GPUs**!
|
||||
### Hardware Requirement
|
||||
|
||||
| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
|
||||
| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -219,6 +242,28 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### Use ModelScope Hub (optional)
|
||||
|
||||
If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
|
||||
|
||||
```bash
|
||||
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||
```
|
||||
|
||||
Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
||||
... # arguments (same as above)
|
||||
```
|
||||
|
||||
LLaMA Board also supports using the models and datasets on the ModelScope Hub.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
||||
```
|
||||
|
||||
### Train on a single GPU
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -229,8 +274,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset wiki_demo \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
@@ -252,8 +297,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
@@ -276,14 +321,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -301,14 +346,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
--output_dir path_to_ppo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
@@ -332,14 +377,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -427,20 +472,26 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--export_dir path_to_export
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Merging LoRA weights into a quantized model is not supported.
|
||||
|
||||
> [!TIP]
|
||||
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model.
|
||||
|
||||
### API Demo
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
@@ -451,9 +502,9 @@ python src/api_demo.py \
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### Web Demo
|
||||
@@ -461,9 +512,9 @@ python src/cli_demo.py \
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
@@ -471,9 +522,9 @@ python src/web_demo.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--finetuning_type lora
|
||||
--task mmlu \
|
||||
--split test \
|
||||
--lang en \
|
||||
@@ -486,12 +537,12 @@ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_predict \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
|
||||
113
README_zh.md
113
README_zh.md
@@ -1,4 +1,4 @@
|
||||
# LLaMA Factory: 轻松的大模型训练与评估
|
||||

|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
@@ -6,7 +6,7 @@
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/c2EPEt5NU)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
|
||||
@@ -31,7 +31,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [模型](#模型)
|
||||
- [训练方法](#训练方法)
|
||||
- [数据集](#数据集)
|
||||
- [软件依赖](#软件依赖)
|
||||
- [软硬件依赖](#软硬件依赖)
|
||||
- [如何使用](#如何使用)
|
||||
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
|
||||
- [协议](#协议)
|
||||
@@ -44,26 +44,36 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||

|
||||
|
||||
<details><summary>变量定义</summary>
|
||||
|
||||
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024)
|
||||
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
|
||||
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
|
||||
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`。
|
||||
|
||||
</details>
|
||||
|
||||
## 更新日志
|
||||
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
|
||||
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
||||
|
||||
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
|
||||
|
||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
|
||||
[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
|
||||
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||
|
||||
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
|
||||
|
||||
[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
[23/07/31] 我们支持了**数据流式加载**。请使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
|
||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
|
||||
|
||||
@@ -77,6 +87,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
|
||||
</details>
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
@@ -91,8 +103,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
> [!NOTE]
|
||||
@@ -156,6 +169,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
@@ -171,6 +185,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -183,7 +198,7 @@ pip install --upgrade huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## 软件依赖
|
||||
## 软硬件依赖
|
||||
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
@@ -192,7 +207,15 @@ huggingface-cli login
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
|
||||
以及 **强而有力的 GPU**!
|
||||
### 硬件依赖
|
||||
|
||||
| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
|
||||
| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
|
||||
|
||||
## 如何使用
|
||||
|
||||
@@ -219,6 +242,28 @@ pip install -r requirements.txt
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### 使用魔搭社区(可跳过)
|
||||
|
||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
|
||||
```bash
|
||||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
```
|
||||
|
||||
接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
||||
... # 参数同上
|
||||
```
|
||||
|
||||
LLaMA Board 同样支持魔搭社区的模型和数据集下载。
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
||||
```
|
||||
|
||||
### 单 GPU 训练
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -229,8 +274,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset wiki_demo \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
@@ -252,8 +297,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
@@ -276,14 +321,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -301,14 +346,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
--output_dir path_to_ppo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
@@ -332,14 +377,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -427,20 +472,26 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--export_dir path_to_export
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> 尚不支持量化模型的 LoRA 权重合并及导出。
|
||||
|
||||
> [!TIP]
|
||||
> 使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化导出模型。
|
||||
|
||||
### API 服务
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
@@ -451,9 +502,9 @@ python src/api_demo.py \
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### 浏览器测试
|
||||
@@ -461,9 +512,9 @@ python src/cli_demo.py \
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### 模型评估
|
||||
@@ -471,9 +522,9 @@ python src/web_demo.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--finetuning_type lora \
|
||||
--task ceval \
|
||||
--split validation \
|
||||
--lang zh \
|
||||
@@ -486,12 +537,12 @@ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_predict \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
|
||||
@@ -4,9 +4,10 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
"dataset_name": {
|
||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
|
||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
|
||||
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
|
||||
"file_name": "the name of the dataset file in this directory. (required if above are not specified)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"columns": {
|
||||
@@ -16,7 +17,8 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
|
||||
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
|
||||
"content": "the key in the message represents the content. (default: value, for sharegpt)"
|
||||
"content": "the key in the message represents the content. (default: value, for sharegpt)",
|
||||
"system": "the column name in the dataset containing the system prompts. (default: None, for both)"
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -31,6 +33,7 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"output": "model response (required)",
|
||||
"system": "system prompt (optional)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
@@ -47,6 +50,7 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"system": "system",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
@@ -54,7 +58,7 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
|
||||
|
||||
The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
||||
The `system` column will be used as the system prompt in the template. The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
||||
|
||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||
|
||||
@@ -85,7 +89,8 @@ The dataset in sharegpt format should follow the below format:
|
||||
"from": "gpt",
|
||||
"value": "model response"
|
||||
}
|
||||
]
|
||||
],
|
||||
"system": "system prompt (optional)"
|
||||
}
|
||||
]
|
||||
```
|
||||
@@ -97,7 +102,8 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
"content": "value",
|
||||
"system": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"hf_hub_url": "Hugging Face 的仓库地址(若指定,则忽略下列三个参数)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
|
||||
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"columns": {
|
||||
@@ -16,7 +17,8 @@
|
||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)"
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None,用于两种格式)"
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -31,6 +33,7 @@
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"system": "系统提示词(选填)",
|
||||
"history": [
|
||||
["第一轮指令(选填)", "第一轮回答(选填)"],
|
||||
["第二轮指令(选填)", "第二轮回答(选填)"]
|
||||
@@ -47,6 +50,7 @@
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"system": "system",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
@@ -54,7 +58,7 @@
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
||||
|
||||
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
`system` 为模板中的系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
@@ -85,7 +89,8 @@
|
||||
"from": "gpt",
|
||||
"value": "模型回答"
|
||||
}
|
||||
]
|
||||
],
|
||||
"system": "系统提示词(选填)"
|
||||
}
|
||||
]
|
||||
```
|
||||
@@ -97,7 +102,8 @@
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
"content": "value",
|
||||
"system": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
38c89869c6aeca2a3af9ea1e09afe460f9b46810
|
||||
@@ -1,9 +1,9 @@
|
||||
torch>=1.13.1
|
||||
transformers>=4.31.0,<4.35.0
|
||||
datasets>=2.14.0
|
||||
transformers>=4.36.1
|
||||
datasets>=2.14.3
|
||||
accelerate>=0.21.0
|
||||
peft>=0.6.0
|
||||
trl>=0.7.4
|
||||
peft>=0.7.0
|
||||
trl==0.7.4
|
||||
gradio>=3.38.0,<4.0.0
|
||||
scipy
|
||||
sentencepiece
|
||||
|
||||
@@ -7,4 +7,4 @@ from llmtuner.train import export_model, run_exp
|
||||
from llmtuner.webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.3.2"
|
||||
__version__ = "0.4.0"
|
||||
|
||||
@@ -15,7 +15,9 @@ from llmtuner.api.protocol import (
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage
|
||||
ChatCompletionResponseUsage,
|
||||
ScoreEvaluationRequest,
|
||||
ScoreEvaluationResponse
|
||||
)
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
@@ -68,6 +70,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if not chat_model.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
@@ -123,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role=Role.ASSISTANT),
|
||||
delta=DeltaMessage(role=Role.ASSISTANT, content=""),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
@@ -156,6 +161,17 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
yield to_json(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||
if chat_model.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -81,3 +81,16 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[str]
|
||||
max_length: Optional[int] = None
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Optional[str] = "scoreeval-default"
|
||||
object: Optional[str] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import tiktoken
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
||||
from threading import Thread
|
||||
@@ -22,11 +23,13 @@ class ChatModel:
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.can_generate = (finetuning_args.stage == "sft")
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
)
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
def _process_args(
|
||||
self,
|
||||
@@ -35,7 +38,6 @@ class ChatModel:
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
system = system or self.system_prompt
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
@@ -130,3 +132,41 @@ class ChatModel:
|
||||
thread.start()
|
||||
|
||||
yield from streamer
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
**input_kwargs
|
||||
) -> List[float]:
|
||||
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=True)
|
||||
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
device = getattr(self.model.pretrained_model, "device", "cuda")
|
||||
|
||||
inputs = self.tokenizer(
|
||||
batch_input,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
|
||||
pad_to_multiple_of=8,
|
||||
return_tensors="pt",
|
||||
**kwargs
|
||||
).to(device)
|
||||
|
||||
input_ids: torch.Tensor = inputs["input_ids"]
|
||||
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if getattr(self.model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
scores = []
|
||||
for i in range(input_ids.size(0)):
|
||||
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
scores.append(values[i, end_index].nan_to_num().item())
|
||||
|
||||
return scores
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
from llmtuner.data.utils import checksum, EXT2TYPE
|
||||
from llmtuner.data.utils import checksum
|
||||
from llmtuner.extras.constants import FILEEXT2TYPE
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -24,27 +25,27 @@ def get_dataset(
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
data_dir = dataset_attr.folder
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path, data_name = None, None
|
||||
data_files: List[str] = []
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
data_files = []
|
||||
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
if os.path.isdir(local_path): # is directory
|
||||
for file_name in os.listdir(local_path):
|
||||
data_files.append(os.path.join(local_path, file_name))
|
||||
if data_path is None:
|
||||
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
|
||||
assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
||||
elif os.path.isfile(local_path): # is file
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
@@ -53,17 +54,37 @@ def get_dataset(
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||
)
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
try:
|
||||
from modelscope import MsDataset # type: ignore
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"):
|
||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||
dataset = MsDataset.load(
|
||||
dataset_name=data_path,
|
||||
subset_name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
).to_hf_dataset()
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
else:
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if max_samples is not None: # truncate dataset
|
||||
@@ -71,8 +92,8 @@ def get_dataset(
|
||||
|
||||
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
# convert dataset from sharegpt format to alpaca format
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": []}
|
||||
for msg_list in examples[dataset_attr.messages]:
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
|
||||
for i, msg_list in enumerate(examples[dataset_attr.messages]):
|
||||
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
||||
if len(msg_list) == 0:
|
||||
continue
|
||||
@@ -95,7 +116,8 @@ def get_dataset(
|
||||
outputs["prompt"].append(msg_pairs[-1][0])
|
||||
outputs["query"].append("")
|
||||
outputs["response"].append(msg_pairs[-1][1])
|
||||
outputs["history"].append(msg_pairs[:-1])
|
||||
outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -116,17 +138,10 @@ def get_dataset(
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align dataset
|
||||
for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.system_prompt: # add system prompt
|
||||
system_prompt = dataset_attr.system_prompt
|
||||
if data_args.streaming:
|
||||
dataset = dataset.map(lambda _: {"system": system_prompt})
|
||||
else:
|
||||
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
|
||||
@@ -408,18 +408,31 @@ register_template(
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
"User: {{query}}\n\nAssistant:"
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer."
|
||||
"and other non-computer science questions, you will refuse to answer\n"
|
||||
),
|
||||
sep=[
|
||||
"\n",
|
||||
{"token": "<|EOT|>"},
|
||||
"\n\n"
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|EOT|>"
|
||||
@@ -528,9 +541,7 @@ register_template(
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
" "
|
||||
]
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
@@ -606,9 +617,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
register_template(
|
||||
name="vanilla",
|
||||
prefix=[],
|
||||
@@ -637,6 +645,23 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xuanyuan",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}} Assistant:"
|
||||
],
|
||||
system=(
|
||||
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
||||
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
||||
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xverse",
|
||||
prefix=[
|
||||
@@ -682,6 +707,25 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yi",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"<|im_end|>\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="zephyr",
|
||||
prefix=[
|
||||
|
||||
@@ -12,16 +12,6 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"arrow": "arrow",
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"parquet": "parquet",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
if file_sha1 is None:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.modeling_utils import custom_object_save, unwrap_model
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
@@ -18,6 +19,16 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
||||
model.pretrained_model.config.save_pretrained(output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(output_dir)
|
||||
elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
|
||||
custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@@ -25,25 +36,17 @@ class SavePeftModelCallback(TrainerCallback):
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
|
||||
model.pretrained_model.config.save_pretrained(output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(output_dir)
|
||||
_save_model_with_valuehead(
|
||||
model=unwrap_model(kwargs.pop("model")),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
|
||||
model.pretrained_model.config.save_pretrained(args.output_dir)
|
||||
if model.pretrained_model.can_generate():
|
||||
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
|
||||
if getattr(model, "is_peft_model", False):
|
||||
model.pretrained_model.save_pretrained(args.output_dir)
|
||||
_save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from collections import defaultdict, OrderedDict
|
||||
from typing import Dict, Optional
|
||||
|
||||
@@ -8,6 +9,15 @@ DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
|
||||
FILEEXT2TYPE = {
|
||||
"arrow": "arrow",
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"parquet": "parquet",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
@@ -16,6 +26,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
PEFT_METHODS = ["lora"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
@@ -28,9 +40,13 @@ TRAINING_STAGES = {
|
||||
"Pre-Training": "pt"
|
||||
}
|
||||
|
||||
class DownloadSource(str, Enum):
|
||||
DEFAULT = "hf"
|
||||
MODELSCOPE = "ms"
|
||||
|
||||
|
||||
def register_model_group(
|
||||
models: Dict[str, str],
|
||||
models: Dict[str, Dict[DownloadSource, str]],
|
||||
module: Optional[str] = None,
|
||||
template: Optional[str] = None
|
||||
) -> None:
|
||||
@@ -49,9 +65,18 @@ def register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
|
||||
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
|
||||
"Baichuan-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
|
||||
},
|
||||
"Baichuan-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
|
||||
},
|
||||
"Baichuan-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
|
||||
}
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan"
|
||||
@@ -60,10 +85,22 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
||||
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
|
||||
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
|
||||
"Baichuan2-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
|
||||
},
|
||||
"Baichuan2-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
|
||||
},
|
||||
"Baichuan2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
|
||||
},
|
||||
"Baichuan2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
|
||||
}
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan2"
|
||||
@@ -72,9 +109,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1"
|
||||
"BLOOM-560M": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
|
||||
},
|
||||
"BLOOM-3B": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
|
||||
},
|
||||
"BLOOM-7B1": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
|
||||
}
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
@@ -82,9 +128,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
|
||||
"BLOOMZ-560M": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
|
||||
},
|
||||
"BLOOMZ-3B": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
|
||||
},
|
||||
"BLOOMZ-7B1-mt": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
|
||||
}
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
@@ -92,8 +147,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
||||
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
|
||||
"BlueLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
|
||||
},
|
||||
"BlueLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
|
||||
}
|
||||
},
|
||||
template="bluelm"
|
||||
)
|
||||
@@ -101,7 +162,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||
"ChatGLM2-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm2"
|
||||
@@ -110,8 +174,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
||||
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
|
||||
"ChatGLM3-6B-Base": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
|
||||
},
|
||||
"ChatGLM3-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm3"
|
||||
@@ -120,12 +190,30 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
|
||||
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
|
||||
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
|
||||
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
|
||||
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
|
||||
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
|
||||
"ChineseLLaMA2-1.3B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
|
||||
},
|
||||
"ChineseLLaMA2-7B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
|
||||
},
|
||||
"ChineseLLaMA2-13B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
|
||||
},
|
||||
"ChineseLLaMA2-1.3B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
|
||||
},
|
||||
"ChineseLLaMA2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
|
||||
},
|
||||
"ChineseLLaMA2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
|
||||
}
|
||||
},
|
||||
template="llama2_zh"
|
||||
)
|
||||
@@ -133,12 +221,76 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": "tiiuae/falcon-7b",
|
||||
"Falcon-40B": "tiiuae/falcon-40b",
|
||||
"Falcon-180B": "tiiuae/falcon-180B",
|
||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
|
||||
"DeepseekLLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
|
||||
},
|
||||
"DeepseekLLM-67B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
|
||||
},
|
||||
"DeepseekLLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
|
||||
},
|
||||
"DeepseekLLM-67B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
|
||||
}
|
||||
},
|
||||
template="deepseek"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepseekCoder-6.7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
|
||||
},
|
||||
"DeepseekCoder-33B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
|
||||
},
|
||||
"DeepseekCoder-6.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
|
||||
},
|
||||
"DeepseekCoder-33B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
|
||||
}
|
||||
},
|
||||
template="deepseekcoder"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
|
||||
},
|
||||
"Falcon-40B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
|
||||
},
|
||||
"Falcon-180B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
|
||||
},
|
||||
"Falcon-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
|
||||
},
|
||||
"Falcon-40B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
|
||||
},
|
||||
"Falcon-180B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="falcon"
|
||||
@@ -147,10 +299,22 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-20B": "internlm/internlm-20b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
|
||||
"InternLM-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
|
||||
},
|
||||
"InternLM-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
|
||||
},
|
||||
"InternLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
|
||||
},
|
||||
"InternLM-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
|
||||
}
|
||||
},
|
||||
template="intern"
|
||||
)
|
||||
@@ -158,7 +322,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
|
||||
"LingoWhale-8B": {
|
||||
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
|
||||
}
|
||||
},
|
||||
module="qkv_proj"
|
||||
)
|
||||
@@ -166,22 +333,52 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": "huggyllama/llama-7b",
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
"LLaMA-30B": "huggyllama/llama-30b",
|
||||
"LLaMA-65B": "huggyllama/llama-65b"
|
||||
"LLaMA-7B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
|
||||
},
|
||||
"LLaMA-13B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
|
||||
},
|
||||
"LLaMA-30B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
|
||||
},
|
||||
"LLaMA-65B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
|
||||
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
|
||||
"LLaMA2-7B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
|
||||
},
|
||||
"LLaMA2-13B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
|
||||
},
|
||||
"LLaMA2-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
|
||||
},
|
||||
"LLaMA2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
|
||||
},
|
||||
"LLaMA2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
|
||||
},
|
||||
"LLaMA2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
|
||||
}
|
||||
},
|
||||
template="llama2"
|
||||
)
|
||||
@@ -189,8 +386,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
|
||||
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
"Mistral-7B": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
|
||||
},
|
||||
"Mistral-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
|
||||
},
|
||||
"Mistral-7B-v0.2-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2"
|
||||
}
|
||||
},
|
||||
template="mistral"
|
||||
)
|
||||
@@ -198,7 +405,25 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
|
||||
"Mixtral-8x7B": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1"
|
||||
},
|
||||
"Mixtral-8x7B-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1"
|
||||
}
|
||||
},
|
||||
template="mistral"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"OpenChat3.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
|
||||
}
|
||||
},
|
||||
template="openchat"
|
||||
)
|
||||
@@ -206,7 +431,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||
"Phi1.5-1.3B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
||||
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
|
||||
}
|
||||
},
|
||||
module="Wqkv"
|
||||
)
|
||||
@@ -214,10 +442,70 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
|
||||
"Qwen-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
|
||||
},
|
||||
"Qwen-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
|
||||
},
|
||||
"Qwen-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
|
||||
},
|
||||
"Qwen-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
|
||||
},
|
||||
"Qwen-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
|
||||
},
|
||||
"Qwen-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
|
||||
},
|
||||
"Qwen-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
|
||||
},
|
||||
"Qwen-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
|
||||
},
|
||||
"Qwen-1.8B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
|
||||
},
|
||||
"Qwen-1.8B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
|
||||
},
|
||||
"Qwen-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
|
||||
},
|
||||
"Qwen-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
|
||||
},
|
||||
"Qwen-14B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
|
||||
},
|
||||
"Qwen-14B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
|
||||
},
|
||||
"Qwen-72B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
|
||||
},
|
||||
"Qwen-72B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
|
||||
}
|
||||
},
|
||||
module="c_attn",
|
||||
template="qwen"
|
||||
@@ -226,15 +514,24 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
|
||||
"Skywork-13B-Base": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
|
||||
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
|
||||
"Vicuna1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
|
||||
},
|
||||
"Vicuna1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
|
||||
}
|
||||
},
|
||||
template="vicuna"
|
||||
)
|
||||
@@ -242,11 +539,49 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": "xverse/XVERSE-7B",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"XVERSE-65B": "xverse/XVERSE-65B",
|
||||
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
|
||||
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
|
||||
"XuanYuan-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
|
||||
},
|
||||
"XuanYuan-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
|
||||
},
|
||||
"XuanYuan-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
|
||||
},
|
||||
"XuanYuan-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
|
||||
}
|
||||
},
|
||||
template="xuanyuan"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
|
||||
},
|
||||
"XVERSE-13B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
|
||||
},
|
||||
"XVERSE-65B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
|
||||
},
|
||||
"XVERSE-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
|
||||
},
|
||||
"XVERSE-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
|
||||
},
|
||||
"XVERSE-65B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat"
|
||||
}
|
||||
},
|
||||
template="xverse"
|
||||
)
|
||||
@@ -254,8 +589,14 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yayi-7B": "wenge-research/yayi-7b-llama2",
|
||||
"Yayi-13B": "wenge-research/yayi-13b-llama2"
|
||||
"Yayi-7B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
|
||||
},
|
||||
"Yayi-13B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
|
||||
}
|
||||
},
|
||||
template="yayi"
|
||||
)
|
||||
@@ -263,16 +604,45 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": "01-ai/Yi-6B",
|
||||
"Yi-34B": "01-ai/Yi-34B"
|
||||
}
|
||||
"Yi-6B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B"
|
||||
},
|
||||
"Yi-34B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
|
||||
},
|
||||
"Yi-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
|
||||
},
|
||||
"Yi-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
|
||||
},
|
||||
"Yi-6B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits"
|
||||
},
|
||||
"Yi-34B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
|
||||
}
|
||||
},
|
||||
template="yi"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
|
||||
"Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
|
||||
"Zephyr-7B-Alpha-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
|
||||
},
|
||||
"Zephyr-7B-Beta-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
||||
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
|
||||
}
|
||||
},
|
||||
template="zephyr"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
|
||||
try:
|
||||
@@ -23,6 +22,7 @@ except ImportError:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
@@ -67,13 +67,18 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_current_device() -> str:
|
||||
def get_current_device() -> torch.device:
|
||||
import accelerate
|
||||
dummy_accelerator = accelerate.Accelerator()
|
||||
if accelerate.utils.is_xpu_available():
|
||||
return "xpu:{}".format(dummy_accelerator.local_process_index)
|
||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif accelerate.utils.is_npu_available():
|
||||
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
else:
|
||||
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
|
||||
device = "cpu"
|
||||
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
@@ -97,17 +102,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
@@ -116,3 +110,23 @@ def torch_gc() -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
||||
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
||||
return
|
||||
|
||||
try:
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
model_args.model_name_or_path = snapshot_download(
|
||||
model_args.model_name_or_path,
|
||||
revision=revision,
|
||||
cache_dir=model_args.cache_dir
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
|
||||
|
||||
def use_modelscope() -> bool:
|
||||
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
|
||||
|
||||
@@ -13,43 +13,37 @@ def get_package_version(name: str) -> str:
|
||||
return "0.0.0"
|
||||
|
||||
|
||||
_fastapi_available = is_package_available("fastapi")
|
||||
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
||||
_jieba_available = is_package_available("jieba")
|
||||
_matplotlib_available = is_package_available("matplotlib")
|
||||
_nltk_available = is_package_available("nltk")
|
||||
_rouge_available = is_package_available("rouge_chinese")
|
||||
_starlette_available = is_package_available("sse_starlette")
|
||||
_uvicorn_available = is_package_available("uvicorn")
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
return _fastapi_available
|
||||
return is_package_available("fastapi")
|
||||
|
||||
|
||||
def is_flash_attn2_available():
|
||||
return _flash_attn2_available
|
||||
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return _jieba_available
|
||||
return is_package_available("jieba")
|
||||
|
||||
|
||||
def is_matplotlib_available():
|
||||
return _matplotlib_available
|
||||
return is_package_available("matplotlib")
|
||||
|
||||
|
||||
def is_nltk_available():
|
||||
return _nltk_available
|
||||
return is_package_available("nltk")
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return is_package_available("requests")
|
||||
|
||||
|
||||
def is_rouge_available():
|
||||
return _rouge_available
|
||||
return is_package_available("rouge_chinese")
|
||||
|
||||
|
||||
def is_starlette_available():
|
||||
return _starlette_available
|
||||
return is_package_available("sse_starlette")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _uvicorn_available
|
||||
return is_package_available("uvicorn")
|
||||
|
||||
@@ -4,14 +4,21 @@ from typing import List, Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
|
||||
|
||||
def use_modelscope() -> bool:
|
||||
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
|
||||
load_from: str
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
@@ -22,6 +29,7 @@ class DatasetAttr:
|
||||
messages: Optional[str] = "conversations"
|
||||
role: Optional[str] = "from"
|
||||
content: Optional[str] = "value"
|
||||
system: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
@@ -96,10 +104,6 @@ class DataArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
|
||||
)
|
||||
val_size: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||
@@ -130,29 +134,40 @@ class DataArguments:
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
except Exception as err:
|
||||
if self.dataset is not None:
|
||||
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
|
||||
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
|
||||
dataset_info = None
|
||||
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
||||
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for i, name in enumerate(dataset_names):
|
||||
for name in dataset_names:
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||
|
||||
if "hf_hub_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||
|
||||
if has_hf_url or has_ms_url:
|
||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||
dataset_attr = DatasetAttr(
|
||||
"ms_hub",
|
||||
dataset_name=dataset_info[name]["ms_hub_url"]
|
||||
)
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"hf_hub",
|
||||
dataset_name=dataset_info[name]["hf_hub_url"]
|
||||
)
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
dataset_attr = DatasetAttr(
|
||||
"script",
|
||||
dataset_name=dataset_info[name]["script_url"]
|
||||
)
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
@@ -168,9 +183,10 @@ class DataArguments:
|
||||
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
|
||||
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
|
||||
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
|
||||
dataset_attr.system = dataset_info[name]["columns"].get("system", None)
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.folder = dataset_info[name].get("folder", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
dataset_attr.system_prompt = prompt_list[i]
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
||||
@@ -15,7 +15,7 @@ class FreezeArguments:
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
||||
Qwen choices: [\"mlp\", \"attn\"], \
|
||||
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
||||
Phi choices: [\"mlp\", \"mixer\"], \
|
||||
Others choices: the same as LLaMA."}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
@@ -33,9 +33,9 @@ class LoraArguments:
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
||||
)
|
||||
lora_alpha: Optional[float] = field(
|
||||
lora_alpha: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.1,
|
||||
@@ -52,12 +52,12 @@ class LoraArguments:
|
||||
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
||||
Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
||||
Others choices: the same as LLaMA."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
create_new_adapter: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."}
|
||||
)
|
||||
|
||||
|
||||
@@ -70,6 +70,14 @@ class RLHFArguments:
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
|
||||
default="sigmoid",
|
||||
metadata={"help": "The type of DPO loss to use."}
|
||||
)
|
||||
dpo_ftx: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
||||
)
|
||||
ppo_buffer_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
|
||||
@@ -98,9 +106,9 @@ class RLHFArguments:
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
||||
)
|
||||
ref_model_checkpoint: Optional[str] = field(
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
|
||||
metadata={"help": "Path to the adapters of the reference model."}
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
@@ -108,24 +116,55 @@ class RLHFArguments:
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
metadata={"help": "Path to the reward model used for the PPO training."}
|
||||
)
|
||||
reward_model_checkpoint: Optional[str] = field(
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
|
||||
metadata={"help": "Path to the adapters of the reward model."}
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."}
|
||||
)
|
||||
reward_model_type: Optional[Literal["lora", "full"]] = field(
|
||||
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
class ExportArguments:
|
||||
r"""
|
||||
Arguments pertaining to model exporting.
|
||||
"""
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
export_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."}
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."}
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
|
||||
)
|
||||
export_quantization_nsamples: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={"help": "The number of samples used for quantization."}
|
||||
)
|
||||
export_quantization_maxlen: Optional[str] = field(
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length of the model inputs used for quantization."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
@@ -141,14 +180,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
default=False,
|
||||
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
||||
)
|
||||
neft_alpha: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
|
||||
)
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
@@ -161,21 +192,25 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
return arg
|
||||
|
||||
self.name_module_trainable = split_arg(self.name_module_trainable)
|
||||
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
|
||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target = split_arg(self.lora_target)
|
||||
self.additional_target = split_arg(self.additional_target)
|
||||
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
|
||||
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
|
||||
self.ref_model_adapters = split_arg(self.ref_model_adapters)
|
||||
self.reward_model_adapters = split_arg(self.reward_model_adapters)
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("Lora reward model only supports lora training.")
|
||||
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
||||
|
||||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
|
||||
@@ -8,11 +8,15 @@ class ModelArguments:
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
||||
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}
|
||||
)
|
||||
use_fast_tokenizer: Optional[bool] = field(
|
||||
default=True,
|
||||
@@ -42,10 +46,6 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Adopt scaled rotary positional embeddings."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
||||
)
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||
@@ -58,6 +58,10 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
)
|
||||
ms_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute_dtype = None
|
||||
@@ -66,8 +70,8 @@ class ModelArguments:
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ def init_adapter(
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.info("Checkpoint is not found at evaluation, load the original model.")
|
||||
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
||||
return model
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
@@ -44,6 +44,7 @@ def init_adapter(
|
||||
)
|
||||
if not num_layers:
|
||||
raise ValueError("Current model does not support freeze tuning.")
|
||||
|
||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
@@ -62,32 +63,33 @@ def init_adapter(
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
checkpoint_to_resume = None
|
||||
adapter_to_resume = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
is_mergeable = True
|
||||
if getattr(model, "quantization_method", None) == "gptq":
|
||||
assert len(model_args.checkpoint_dir) == 1, "GPTQ quantized model only accepts a single checkpoint."
|
||||
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
||||
is_mergeable = False
|
||||
|
||||
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable):
|
||||
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
adapter_to_merge = model_args.adapter_name_or_path
|
||||
|
||||
for checkpoint in checkpoints_to_merge:
|
||||
model = PeftModel.from_pretrained(model, checkpoint)
|
||||
for adapter in adapter_to_merge:
|
||||
model = PeftModel.from_pretrained(model, adapter)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
if len(adapter_to_merge) > 0:
|
||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
|
||||
if checkpoint_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
||||
target_modules = find_all_linear_modules(model)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
@@ -102,7 +104,10 @@ def init_adapter(
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
return model
|
||||
|
||||
@@ -1,61 +1,46 @@
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase
|
||||
)
|
||||
from transformers.models.llama import modeling_llama as LlamaModule
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
try:
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
import llmtuner.model.patcher as patcher
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, try_download_model_from_ms
|
||||
from llmtuner.model.adapter import init_adapter
|
||||
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
|
||||
from llmtuner.model.utils import (
|
||||
load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
|
||||
require_version("transformers>=4.36.1", "To fix: pip install transformers>=4.36.1")
|
||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
|
||||
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
||||
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
|
||||
require_version("trl==0.7.4", "To fix: pip install trl==0.7.4")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
|
||||
add_valuehead: Optional[bool] = False
|
||||
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
|
||||
try_download_model_from_ms(model_args)
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
@@ -70,151 +55,49 @@ def load_model_and_tokenizer(
|
||||
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
|
||||
**config_kwargs
|
||||
)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
patcher.patch_tokenizer(tokenizer)
|
||||
patcher.patch_config(config, model_args)
|
||||
patcher.configure_rope(config, model_args, is_trainable)
|
||||
patcher.configure_flashattn(config_kwargs, model_args)
|
||||
patcher.configure_longlora(config, model_args, is_trainable)
|
||||
patcher.configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
# Set model dtype
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
||||
|
||||
# Fix config (for Qwen)
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
||||
|
||||
# Set RoPE scaling
|
||||
if model_args.rope_scaling is not None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
if is_flash_attn2_available():
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
else:
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support FlashAttention.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
||||
|
||||
# Set shift short attention (S^2-Attn)
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library)
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
if model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
# Load pre-trained models (without valuehead)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype=model_args.compute_dtype,
|
||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||
**config_kwargs
|
||||
)
|
||||
patcher.patch_model(model)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
resize_embedding_layer(model, tokenizer)
|
||||
|
||||
# Disable custom generate method (for Qwen and Baichuan2)
|
||||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
# Register auto class to save the custom code files
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||
model.__class__.register_for_auto_class()
|
||||
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage in ["rm", "ppo"]:
|
||||
if add_valuehead:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
|
||||
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
||||
vhead_path = (
|
||||
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
||||
)
|
||||
patcher.patch_valuehead_model(model)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
vhead_path = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
vhead_path = model_args.model_name_or_path
|
||||
|
||||
vhead_params = load_valuehead_params(vhead_path, model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
||||
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import logging
|
||||
import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@@ -7,7 +9,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import parse_args
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
@@ -40,47 +41,72 @@ _EVAL_CLS = Tuple[
|
||||
]
|
||||
|
||||
|
||||
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
|
||||
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
|
||||
|
||||
return (*parsed_args,)
|
||||
|
||||
|
||||
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
|
||||
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
||||
if model_args.quantization_bit is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if finetuning_args.create_new_adapter:
|
||||
raise ValueError("Cannot create new adapter upon a quantized model.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Multiple adapters are only available for LoRA tuning.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
log_level = training_args.get_process_log_level()
|
||||
_set_transformers_logging(log_level)
|
||||
|
||||
# Check arguments
|
||||
data_args.init_for_training(training_args.seed)
|
||||
|
||||
@@ -139,11 +165,18 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
|
||||
can_resume_from_checkpoint = False
|
||||
training_args.resume_from_checkpoint = None
|
||||
else:
|
||||
can_resume_from_checkpoint = True
|
||||
|
||||
if (
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
@@ -158,7 +191,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
))
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
||||
logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
|
||||
logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
))
|
||||
|
||||
@@ -182,7 +215,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
_set_transformers_logging()
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
@@ -193,7 +227,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
_set_transformers_logging()
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
184
src/llmtuner/model/patcher.py
Normal file
184
src/llmtuner/model/patcher.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import random
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner.extras.constants import FILEEXT2TYPE
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import get_current_device, infer_optim_dtype
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
|
||||
|
||||
|
||||
def configure_flashattn(config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
|
||||
if model_args.flash_attn and is_flash_attn2_available():
|
||||
config_kwargs["use_flash_attention_2"] = True
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
||||
|
||||
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
|
||||
def configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
config_kwargs: Dict[str, Any],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
):
|
||||
if getattr(config, "quantization_config", None): # gptq or awq
|
||||
model_args.quantization_bit = None # remove bnb quantization
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config = getattr(config, "quantization_config", None)
|
||||
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
|
||||
|
||||
if model_args.quantization_bit is not None: # bnb
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
if model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if finetuning_args.export_quantization_bit is not None: # gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
|
||||
config_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=finetuning_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
|
||||
)
|
||||
config_kwargs["device_map"] = "auto"
|
||||
logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
|
||||
|
||||
|
||||
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
|
||||
if model_args.rope_scaling is not None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
|
||||
def get_quantization_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
) -> List[str]:
|
||||
r"""
|
||||
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
|
||||
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||
"""
|
||||
if os.path.isfile(finetuning_args.export_quantization_dataset):
|
||||
data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None)
|
||||
data_files = finetuning_args.export_quantization_dataset
|
||||
else:
|
||||
data_path = finetuning_args.export_quantization_dataset
|
||||
data_files = None
|
||||
|
||||
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
|
||||
maxlen = finetuning_args.export_quantization_maxlen
|
||||
|
||||
samples = []
|
||||
for _ in range(finetuning_args.export_quantization_nsamples):
|
||||
while True:
|
||||
sample_idx = random.randint(0, len(dataset) - 1)
|
||||
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
if sample["input_ids"].size(1) >= maxlen:
|
||||
break # TODO: fix large maxlen
|
||||
|
||||
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||
input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen]
|
||||
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
||||
|
||||
|
||||
def patch_model(model: "PreTrainedModel"):
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
if getattr(model.config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
|
||||
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
|
||||
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
|
||||
return self.pretrained_model.get_input_embeddings()
|
||||
|
||||
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
||||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
import torch
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from transformers.utils import cached_file
|
||||
@@ -10,7 +10,7 @@ from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
"""
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||
if getattr(model, "quantization_method", None): # already set on current device
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
@@ -42,18 +42,18 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
return model.cuda()
|
||||
|
||||
|
||||
def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora.
|
||||
"""
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
||||
else:
|
||||
quantization_method = getattr(model, "quantization_method", None)
|
||||
if quantization_method is None:
|
||||
linear_cls = torch.nn.Linear
|
||||
elif quantization_method == "bitsandbytes":
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
|
||||
else:
|
||||
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
|
||||
|
||||
output_layer_names = ["lm_head"]
|
||||
if model.config.model_type == "chatglm":
|
||||
@@ -85,10 +85,7 @@ def get_modelcard_args(
|
||||
}
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
path_or_repo_id: str,
|
||||
model_args: "ModelArguments"
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
@@ -96,22 +93,10 @@ def load_valuehead_params(
|
||||
"""
|
||||
kwargs = {
|
||||
"path_or_repo_id": path_or_repo_id,
|
||||
"cache_dir": model_args.cache_dir
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
if "token" in inspect.signature(cached_file).parameters:
|
||||
kwargs["token"] = model_args.hf_hub_token
|
||||
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
|
||||
kwargs["use_auth_token"] = model_args.hf_hub_token
|
||||
else:
|
||||
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
@@ -123,10 +108,24 @@ def load_valuehead_params(
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
||||
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
||||
return None
|
||||
|
||||
|
||||
def noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(avg_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
@@ -147,17 +146,6 @@ def prepare_model_for_training(
|
||||
param.data = param.data.to(torch.float32)
|
||||
logger.info("Upcasting weights in layernorm in float32.")
|
||||
|
||||
if finetuning_args.neft_alpha > 1e-6:
|
||||
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
if module.training:
|
||||
dims = torch.tensor(output.size(1) * output.size(2))
|
||||
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
|
||||
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
||||
return output
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
|
||||
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
|
||||
|
||||
if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
@@ -181,3 +169,31 @@ def prepare_model_for_training(
|
||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Resize token embeddings.
|
||||
"""
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
logger.warning("Current model does not support resizing token embeddings.")
|
||||
return
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
|
||||
|
||||
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
|
||||
if "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||
model.__class__.register_for_auto_class()
|
||||
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
|
||||
@@ -16,10 +16,11 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
loss_type: Literal["sigmoid", "hinge"],
|
||||
ftx_gamma: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||
**kwargs
|
||||
):
|
||||
if disable_dropout:
|
||||
@@ -34,6 +35,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = beta
|
||||
self.label_smoothing = 0
|
||||
self.ftx_gamma = ftx_gamma
|
||||
self.loss_type = loss_type
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
@@ -51,10 +54,28 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def sft_loss(
|
||||
self,
|
||||
chosen_logits: torch.FloatTensor,
|
||||
chosen_labels: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self._get_batch_logps(
|
||||
chosen_logits,
|
||||
chosen_labels,
|
||||
average_log_prob=True
|
||||
)
|
||||
return -all_logps
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
batch: Optional[Dict[str, torch.Tensor]] = None
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
@@ -73,3 +94,61 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
||||
|
||||
def get_batch_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor],
|
||||
train_eval: Optional[Literal["train", "eval"]] = "train"
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.model, batch)
|
||||
else:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
)
|
||||
if self.ftx_gamma > 1e-6:
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
||||
@@ -25,11 +25,11 @@ def run_dpo(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
@@ -37,7 +37,7 @@ def run_dpo(
|
||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||
ref_model = model
|
||||
else:
|
||||
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
@@ -47,6 +47,8 @@ def run_dpo(
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
loss_type=finetuning_args.dpo_loss,
|
||||
ftx_gamma=finetuning_args.dpo_ftx,
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
|
||||
@@ -3,9 +3,10 @@ import sys
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from trl import PPOTrainer
|
||||
@@ -14,7 +15,7 @@ from trl.core import PPODecorators, logprobs_from_logits
|
||||
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
@@ -55,17 +56,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
|
||||
self.accelerator.state, "deepspeed_plugin"
|
||||
)
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
|
||||
|
||||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
if reward_model is not None:
|
||||
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
|
||||
self.accelerator.state, "deepspeed_plugin"
|
||||
)
|
||||
if is_deepspeed_enabled:
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
if self.is_deepspeed_enabled:
|
||||
if not (
|
||||
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
||||
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
||||
@@ -198,15 +199,20 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
for k, v in batch.items():
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config,
|
||||
logits_processor=get_logits_processor(),
|
||||
**batch
|
||||
@@ -215,10 +221,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
query = batch["input_ids"].detach().cpu()
|
||||
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
@@ -226,7 +233,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
else:
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_length:]) # remove padding from left
|
||||
queries.append(query[i, query_start_index:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
|
||||
return queries, responses
|
||||
@@ -240,17 +247,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
|
||||
Both inputs and outputs are put on CPU.
|
||||
"""
|
||||
if self.reward_model is None:
|
||||
if self.finetuning_args.reward_model_type == "api":
|
||||
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
|
||||
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
reward_model = self.reward_model if self.reward_model is not None else self.model
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
@@ -259,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
if self.reward_model is None:
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
return rewards
|
||||
@@ -298,7 +314,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
if values.size(0) != input_ids.size(0): # adapt to chatglm2
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
||||
@@ -344,4 +361,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
if self.args.should_save:
|
||||
self._save(output_dir)
|
||||
try:
|
||||
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
|
||||
" use zero_to_fp32.py to recover weights"
|
||||
)
|
||||
self._save(output_dir, state_dict={})
|
||||
for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
|
||||
file = os.path.join(output_dir, filename)
|
||||
if os.path.isfile(file):
|
||||
os.remove(file)
|
||||
|
||||
self.model.save_checkpoint(output_dir) # wrapped model
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
import json
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||
|
||||
from llmtuner.extras.packages import is_requests_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
|
||||
|
||||
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = {"model": "model", "messages": messages}
|
||||
response = requests.post(server_url, json=payload, headers=headers)
|
||||
rewards = json.loads(response.text)["scores"]
|
||||
return torch.Tensor(rewards)
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
|
||||
@@ -28,14 +28,14 @@ def run_ppo(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# Create reference model and reward model
|
||||
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
|
||||
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
|
||||
reward_model = create_reward_model(model, model_args, finetuning_args)
|
||||
|
||||
# Create ppo config
|
||||
|
||||
@@ -22,7 +22,7 @@ def run_pt(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
|
||||
@@ -39,7 +39,9 @@ class PairwiseTrainer(Trainer):
|
||||
"""
|
||||
# Compute rewards
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
|
||||
|
||||
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
# Split the inputs and rewards into two parts, chosen and rejected
|
||||
|
||||
@@ -25,9 +25,9 @@ def run_rm(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
|
||||
@@ -26,7 +26,7 @@ def run_sft(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
@@ -34,7 +34,7 @@ def run_sft(
|
||||
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
|
||||
@@ -34,15 +34,20 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.export_quantization_bit is not None:
|
||||
raise ValueError("Please merge adapters before quantizing the model.")
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
if getattr(model, "quantization_method", None) == "gptq":
|
||||
raise ValueError("Cannot export a GPTQ quantized model.")
|
||||
if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
|
||||
logger.warning("Cannot merge adapters to a quantized model.")
|
||||
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
|
||||
model = model.to("cpu")
|
||||
model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))
|
||||
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
@@ -35,7 +35,7 @@ def create_modelcard_and_push(
|
||||
def create_ref_model(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
stage: Literal["ppo", "dpo"]
|
||||
add_valuehead: Optional[bool] = False
|
||||
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
||||
r"""
|
||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
@@ -46,18 +46,22 @@ def create_ref_model(
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.ref_model,
|
||||
checkpoint_dir=finetuning_args.ref_model_checkpoint,
|
||||
adapter_name_or_path=finetuning_args.ref_model_adapters,
|
||||
quantization_bit=finetuning_args.ref_model_quantization_bit
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
|
||||
ref_model, _ = load_model_and_tokenizer(
|
||||
ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
|
||||
else:
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
|
||||
ref_model, _ = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
logger.info("Created reference model from the model itself.")
|
||||
|
||||
return ref_model
|
||||
@@ -71,7 +75,11 @@ def create_reward_model(
|
||||
r"""
|
||||
Creates reward model for PPO training.
|
||||
"""
|
||||
if finetuning_args.reward_model_type == "lora":
|
||||
if finetuning_args.reward_model_type == "api":
|
||||
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
||||
logger.info("Use reward server {}".format(finetuning_args.reward_model))
|
||||
return finetuning_args.reward_model
|
||||
elif finetuning_args.reward_model_type == "lora":
|
||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
@@ -88,12 +96,14 @@ def create_reward_model(
|
||||
reward_model_args_dict = model_args.to_dict()
|
||||
reward_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.reward_model,
|
||||
checkpoint_dir=finetuning_args.reward_model_checkpoint,
|
||||
adapter_name_or_path=finetuning_args.reward_model_adapters,
|
||||
quantization_bit=finetuning_args.reward_model_quantization_bit
|
||||
))
|
||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
|
||||
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
reward_model, _ = load_model_and_tokenizer(
|
||||
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
||||
return reward_model
|
||||
|
||||
@@ -63,21 +63,20 @@ class WebChatModel(ChatModel):
|
||||
yield error
|
||||
return
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
if get("top.adapter_path"):
|
||||
adapter_name_or_path = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("top.adapter_path")])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
adapter_name_or_path = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=get("top.model_path"),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
|
||||
@@ -90,6 +89,7 @@ class WebChatModel(ChatModel):
|
||||
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||
|
||||
if self.demo_mode:
|
||||
gr.Warning(ALERTS["err_demo"][lang])
|
||||
yield ALERTS["err_demo"][lang]
|
||||
return
|
||||
|
||||
|
||||
@@ -2,31 +2,25 @@ import os
|
||||
import json
|
||||
import gradio as gr
|
||||
from typing import Any, Dict, Optional
|
||||
from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.constants import (
|
||||
DEFAULT_MODULE,
|
||||
DEFAULT_TEMPLATE,
|
||||
PEFT_METHODS,
|
||||
SUPPORTED_MODELS,
|
||||
TRAINING_STAGES,
|
||||
DownloadSource
|
||||
)
|
||||
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
||||
from llmtuner.extras.misc import use_modelscope
|
||||
from llmtuner.hparams.data_args import DATA_CONFIG
|
||||
|
||||
|
||||
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
CKPT_NAMES = [
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
]
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
@@ -58,7 +52,15 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
||||
|
||||
def get_model_path(model_name: str) -> str:
|
||||
user_config = load_config()
|
||||
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
|
||||
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, [])
|
||||
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "")
|
||||
if (
|
||||
use_modelscope()
|
||||
and path_dict.get(DownloadSource.MODELSCOPE)
|
||||
and model_path == path_dict.get(DownloadSource.DEFAULT)
|
||||
): # replace path
|
||||
model_path = path_dict.get(DownloadSource.MODELSCOPE)
|
||||
return model_path
|
||||
|
||||
|
||||
def get_prefix(model_name: str) -> str:
|
||||
@@ -75,26 +77,29 @@ def get_template(model_name: str) -> str:
|
||||
return "default"
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
if model_name:
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
if finetuning_type not in PEFT_METHODS:
|
||||
return gr.update(value=[], choices=[], interactive=False)
|
||||
|
||||
adapters = []
|
||||
if model_name and finetuning_type == "lora":
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
for adapter in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
os.path.isdir(os.path.join(save_dir, adapter))
|
||||
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
adapters.append(adapter)
|
||||
return gr.update(value=[], choices=adapters, interactive=True)
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
|
||||
except Exception as err:
|
||||
print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@@ -21,8 +21,11 @@ def next_page(page_index: int, total_num: int) -> int:
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
except:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
|
||||
@@ -38,10 +38,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
input_elems.update({max_new_tokens, top_p, temperature})
|
||||
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
||||
elem_dict.update(dict(
|
||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
|
||||
@@ -10,14 +10,19 @@ if TYPE_CHECKING:
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
GPTQ_BITS = ["8", "4", "3", "2"]
|
||||
|
||||
|
||||
def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
checkpoints: List[str],
|
||||
adapter_path: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
export_quantization_bit: int,
|
||||
export_quantization_dataset: str,
|
||||
export_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
error = ""
|
||||
@@ -25,34 +30,46 @@ def save_model(
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not model_path:
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
elif not checkpoints:
|
||||
error = ALERTS["err_no_checkpoint"][lang]
|
||||
elif not export_dir:
|
||||
error = ALERTS["err_no_export_dir"][lang]
|
||||
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
|
||||
error = ALERTS["err_no_dataset"][lang]
|
||||
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
|
||||
error = ALERTS["err_no_adapter"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
if adapter_path:
|
||||
adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
|
||||
else:
|
||||
adapter_name_or_path = None
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_path,
|
||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
export_dir=export_dir
|
||||
export_dir=export_dir,
|
||||
export_size=max_shard_size,
|
||||
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
|
||||
export_quantization_dataset=export_quantization_dataset
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||
export_model(args)
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
export_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
|
||||
export_dir = gr.Textbox()
|
||||
export_btn = gr.Button()
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
@@ -62,18 +79,22 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
engine.manager.get_elem_by_name("top.lang"),
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.model_path"),
|
||||
engine.manager.get_elem_by_name("top.checkpoints"),
|
||||
engine.manager.get_elem_by_name("top.adapter_path"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_name("top.template"),
|
||||
max_shard_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
export_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
export_dir=export_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_quantization_bit=export_quantization_bit,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_dir=export_dir,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from llmtuner.data.template import templates
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
|
||||
from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -20,24 +20,21 @@ def create_top() -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||
adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
|
||||
system_prompt = gr.Textbox(scale=2)
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
flash_attn = gr.Checkbox(value=False)
|
||||
shift_attn = gr.Checkbox(value=False)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(
|
||||
@@ -47,13 +44,13 @@ def create_top() -> Dict[str, "Component"]:
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
|
||||
).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||
)
|
||||
|
||||
refresh_btn.click(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
@@ -61,14 +58,12 @@ def create_top() -> Dict[str, "Component"]:
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
checkpoints=checkpoints,
|
||||
adapter_path=adapter_path,
|
||||
refresh_btn=refresh_btn,
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
llama_tab=llama_tab,
|
||||
rope_scaling=rope_scaling,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling
|
||||
shift_attn=shift_attn
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import gen_plot
|
||||
|
||||
@@ -60,21 +60,21 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
|
||||
))
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Accordion(label="Extra config", open=False) as extra_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
|
||||
with gr.Column():
|
||||
train_on_prompt = gr.Checkbox(value=False)
|
||||
upcast_layernorm = gr.Checkbox(value=False)
|
||||
|
||||
input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
|
||||
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, train_on_prompt, upcast_layernorm})
|
||||
elem_dict.update(dict(
|
||||
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
||||
neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
|
||||
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
||||
neftune_alpha=neftune_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
|
||||
))
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
@@ -83,22 +83,22 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=1)
|
||||
additional_target = gr.Textbox(scale=1)
|
||||
resume_lora_training = gr.Checkbox(value=True, scale=1)
|
||||
create_new_adapter = gr.Checkbox(scale=1)
|
||||
|
||||
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
|
||||
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
|
||||
elem_dict.update(dict(
|
||||
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
|
||||
additional_target=additional_target, resume_lora_training=resume_lora_training,
|
||||
additional_target=additional_target, create_new_adapter=create_new_adapter
|
||||
))
|
||||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=3)
|
||||
reward_model = gr.Dropdown(scale=3, allow_custom_value=True)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
refresh_btn.click(
|
||||
list_checkpoint,
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False
|
||||
|
||||
@@ -49,7 +49,10 @@ class Engine:
|
||||
else:
|
||||
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict({"train.output_dir": {"value": get_time()}})
|
||||
yield self._form_dict({
|
||||
"train.output_dir": {"value": "train_" + get_time()},
|
||||
"eval.output_dir": {"value": "eval_" + get_time()},
|
||||
})
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
return {
|
||||
|
||||
@@ -75,4 +75,4 @@ def create_web_demo() -> gr.Blocks:
|
||||
if __name__ == "__main__":
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
|
||||
@@ -33,20 +33,20 @@ LOCALES = {
|
||||
"label": "微调方法"
|
||||
}
|
||||
},
|
||||
"checkpoints": {
|
||||
"adapter_path": {
|
||||
"en": {
|
||||
"label": "Checkpoints"
|
||||
"label": "Adapter path"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型断点"
|
||||
"label": "适配器路径"
|
||||
}
|
||||
},
|
||||
"refresh_btn": {
|
||||
"en": {
|
||||
"value": "Refresh checkpoints"
|
||||
"value": "Refresh adapters"
|
||||
},
|
||||
"zh": {
|
||||
"value": "刷新断点"
|
||||
"value": "刷新适配器"
|
||||
}
|
||||
},
|
||||
"advanced_tab": {
|
||||
@@ -77,22 +77,12 @@ LOCALES = {
|
||||
"info": "构建提示词时使用的模板"
|
||||
}
|
||||
},
|
||||
"system_prompt": {
|
||||
"rope_scaling": {
|
||||
"en": {
|
||||
"label": "System prompt (optional)",
|
||||
"info": "A sequence used as the default system prompt."
|
||||
"label": "RoPE scaling"
|
||||
},
|
||||
"zh": {
|
||||
"label": "系统提示词(非必填)",
|
||||
"info": "默认使用的系统提示词"
|
||||
}
|
||||
},
|
||||
"llama_tab": {
|
||||
"en": {
|
||||
"label": "Model configurations (LLaMA only)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型设置(仅LLaMA)"
|
||||
"label": "RoPE 插值方法"
|
||||
}
|
||||
},
|
||||
"flash_attn": {
|
||||
@@ -111,14 +101,6 @@ LOCALES = {
|
||||
"label": "使用 shift short attention (S^2-Attn)"
|
||||
}
|
||||
},
|
||||
"rope_scaling": {
|
||||
"en": {
|
||||
"label": "RoPE scaling"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RoPE 插值方法"
|
||||
}
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
@@ -132,7 +114,7 @@ LOCALES = {
|
||||
"dataset_dir": {
|
||||
"en": {
|
||||
"label": "Data dir",
|
||||
"info": "Path of the data directory."
|
||||
"info": "Path to the data directory."
|
||||
},
|
||||
"zh": {
|
||||
"label": "数据路径",
|
||||
@@ -303,6 +285,14 @@ LOCALES = {
|
||||
"info": "验证集占全部样本的百分比。"
|
||||
}
|
||||
},
|
||||
"extra_tab": {
|
||||
"en": {
|
||||
"label": "Extra configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "其它参数设置"
|
||||
}
|
||||
},
|
||||
"logging_steps": {
|
||||
"en": {
|
||||
"label": "Logging steps",
|
||||
@@ -333,7 +323,7 @@ LOCALES = {
|
||||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
},
|
||||
"neft_alpha": {
|
||||
"neftune_alpha": {
|
||||
"en": {
|
||||
"label": "NEFTune Alpha",
|
||||
"info": "Magnitude of noise adding to embedding vectors."
|
||||
@@ -411,14 +401,14 @@ LOCALES = {
|
||||
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"resume_lora_training": {
|
||||
"create_new_adapter": {
|
||||
"en": {
|
||||
"label": "Resume LoRA training",
|
||||
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
|
||||
"label": "Create new adapter",
|
||||
"info": "Whether to create a new adapter with randomly initialized weight or not."
|
||||
},
|
||||
"zh": {
|
||||
"label": "继续上次的训练",
|
||||
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
|
||||
"label": "新建适配器",
|
||||
"info": "是否创建一个经过随机初始化的新适配器。"
|
||||
}
|
||||
},
|
||||
"rlhf_tab": {
|
||||
@@ -442,11 +432,11 @@ LOCALES = {
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
|
||||
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
|
||||
"info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
|
||||
}
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
@@ -475,12 +465,12 @@ LOCALES = {
|
||||
},
|
||||
"output_dir": {
|
||||
"en": {
|
||||
"label": "Checkpoint name",
|
||||
"info": "Directory to save checkpoint."
|
||||
"label": "Output dir",
|
||||
"info": "Directory for saving results."
|
||||
},
|
||||
"zh": {
|
||||
"label": "断点名称",
|
||||
"info": "保存模型断点的文件夹名称。"
|
||||
"label": "输出目录",
|
||||
"info": "保存结果的路径。"
|
||||
}
|
||||
},
|
||||
"output_box": {
|
||||
@@ -595,6 +585,36 @@ LOCALES = {
|
||||
"label": "温度系数"
|
||||
}
|
||||
},
|
||||
"max_shard_size": {
|
||||
"en": {
|
||||
"label": "Max shard size (GB)",
|
||||
"info": "The maximum size for a model file."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大分块大小(GB)",
|
||||
"info": "单个模型文件的最大大小。"
|
||||
}
|
||||
},
|
||||
"export_quantization_bit": {
|
||||
"en": {
|
||||
"label": "Export quantization bit.",
|
||||
"info": "Quantizing the exported model."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出量化等级",
|
||||
"info": "量化导出模型。"
|
||||
}
|
||||
},
|
||||
"export_quantization_dataset": {
|
||||
"en": {
|
||||
"label": "Export quantization dataset.",
|
||||
"info": "The calibration dataset used for quantization."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出量化数据集",
|
||||
"info": "量化过程中使用的校准数据集。"
|
||||
}
|
||||
},
|
||||
"export_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
@@ -605,16 +625,6 @@ LOCALES = {
|
||||
"info": "保存导出模型的文件夹路径。"
|
||||
}
|
||||
},
|
||||
"max_shard_size": {
|
||||
"en": {
|
||||
"label": "Max shard size (GB)",
|
||||
"info": "The maximum size for a model file."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大分块大小(GB)",
|
||||
"info": "模型文件的最大大小。"
|
||||
}
|
||||
},
|
||||
"export_btn": {
|
||||
"en": {
|
||||
"value": "Export"
|
||||
@@ -647,9 +657,9 @@ ALERTS = {
|
||||
"en": "Please choose a dataset.",
|
||||
"zh": "请选择数据集。"
|
||||
},
|
||||
"err_no_checkpoint": {
|
||||
"en": "Please select a checkpoint.",
|
||||
"zh": "请选择断点。"
|
||||
"err_no_adapter": {
|
||||
"en": "Please select an adapter.",
|
||||
"zh": "请选择一个适配器。"
|
||||
},
|
||||
"err_no_export_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
|
||||
@@ -21,11 +21,10 @@ class Manager:
|
||||
self.all_elems["top"]["lang"],
|
||||
self.all_elems["top"]["model_name"],
|
||||
self.all_elems["top"]["model_path"],
|
||||
self.all_elems["top"]["checkpoints"],
|
||||
self.all_elems["top"]["adapter_path"],
|
||||
self.all_elems["top"]["finetuning_type"],
|
||||
self.all_elems["top"]["quantization_bit"],
|
||||
self.all_elems["top"]["template"],
|
||||
self.all_elems["top"]["system_prompt"],
|
||||
self.all_elems["top"]["flash_attn"],
|
||||
self.all_elems["top"]["shift_attn"],
|
||||
self.all_elems["top"]["rope_scaling"]
|
||||
|
||||
@@ -86,23 +86,22 @@ class Runner:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), ckpt
|
||||
) for ckpt in get("top.checkpoints")])
|
||||
if get("top.adapter_path"):
|
||||
adapter_name_or_path = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("top.adapter_path")])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
adapter_name_or_path = None
|
||||
|
||||
args = dict(
|
||||
stage=TRAINING_STAGES[get("train.training_stage")],
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_train=True,
|
||||
model_name_or_path=get("top.model_path"),
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
@@ -119,24 +118,21 @@ class Runner:
|
||||
logging_steps=get("train.logging_steps"),
|
||||
save_steps=get("train.save_steps"),
|
||||
warmup_steps=get("train.warmup_steps"),
|
||||
neft_alpha=get("train.neft_alpha"),
|
||||
neftune_noise_alpha=get("train.neftune_alpha"),
|
||||
train_on_prompt=get("train.train_on_prompt"),
|
||||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
lora_rank=get("train.lora_rank"),
|
||||
lora_dropout=get("train.lora_dropout"),
|
||||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||
additional_target=get("train.additional_target") if get("train.additional_target") else None,
|
||||
resume_lora_training=get("train.resume_lora_training"),
|
||||
create_new_adapter=get("train.create_new_adapter"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
||||
)
|
||||
args[get("train.compute_type")] = True
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||
args["resume_lora_training"] = (args["quantization_bit"] is not None)
|
||||
|
||||
if args["quantization_bit"] is not None:
|
||||
args["upcast_layernorm"] = True
|
||||
args["create_new_adapter"] = (args["quantization_bit"] is None)
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = get_save_dir(
|
||||
@@ -159,28 +155,22 @@ class Runner:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), ckpt
|
||||
) for ckpt in get("top.checkpoints")])
|
||||
output_dir = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
||||
)
|
||||
if get("top.adapter_path"):
|
||||
adapter_name_or_path = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("top.adapter_path")])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
|
||||
adapter_name_or_path = None
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_eval=True,
|
||||
predict_with_generate=True,
|
||||
model_name_or_path=get("top.model_path"),
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
@@ -189,10 +179,11 @@ class Runner:
|
||||
cutoff_len=get("eval.cutoff_len"),
|
||||
max_samples=int(get("eval.max_samples")),
|
||||
per_device_eval_batch_size=get("eval.batch_size"),
|
||||
predict_with_generate=True,
|
||||
max_new_tokens=get("eval.max_new_tokens"),
|
||||
top_p=get("eval.top_p"),
|
||||
temperature=get("eval.temperature"),
|
||||
output_dir=output_dir
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
|
||||
)
|
||||
|
||||
if get("eval.predict"):
|
||||
@@ -242,6 +233,7 @@ class Runner:
|
||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
|
||||
"{}.output_dir".format("train" if self.do_train else "eval")
|
||||
))
|
||||
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
|
||||
@@ -44,9 +44,10 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
|
||||
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
|
||||
for k, v in args.items():
|
||||
if v is not None and v != "":
|
||||
if v is not None and v is not False and v != "":
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
cmd_text = "\\\n".join(cmd_lines)
|
||||
cmd_text = "```bash\n{}\n```".format(cmd_text)
|
||||
|
||||
@@ -4,7 +4,7 @@ from llmtuner import create_ui
|
||||
def main():
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,7 +4,7 @@ from llmtuner import create_web_demo
|
||||
def main():
|
||||
demo = create_web_demo()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -26,7 +26,7 @@ def calculate_lr(
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool, # mistral model uses a smaller learning rate,
|
||||
dataset_dir: Optional[str] = "data"
|
||||
dataset_dir: Optional[str] = "../data"
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
|
||||
stage="sft",
|
||||
@@ -38,7 +38,7 @@ def calculate_lr(
|
||||
output_dir="dummy_dir"
|
||||
))
|
||||
trainset = get_dataset(model_args, data_args)
|
||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
||||
trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
dataloader = DataLoader(
|
||||
|
||||
77
tests/loftq_init.py
Normal file
77
tests/loftq_init.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# coding=utf-8
|
||||
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
||||
|
||||
import os
|
||||
import fire
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||
|
||||
|
||||
class Shell(nn.Module):
|
||||
|
||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):
|
||||
parent_name = ".".join(name.split(".")[:-1])
|
||||
child_name = name.split(".")[-1]
|
||||
parent_module = model.get_submodule(parent_name)
|
||||
child_module = getattr(parent_module, child_name)
|
||||
base_layer = getattr(child_module, "base_layer")
|
||||
weight = getattr(base_layer, "weight", None)
|
||||
bias = getattr(base_layer, "bias", None)
|
||||
setattr(parent_module, child_name, Shell(weight, bias))
|
||||
|
||||
print("Model unwrapped.")
|
||||
|
||||
|
||||
def quantize_loftq(
|
||||
model_name_or_path: str,
|
||||
save_dir: str,
|
||||
loftq_bits: Optional[int] = 4,
|
||||
loftq_iter: Optional[int] = 1,
|
||||
lora_alpha: Optional[int] = None,
|
||||
lora_rank: Optional[int] = 16,
|
||||
lora_target: Optional[str] = "q_proj,v_proj"
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=True,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=0.1,
|
||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config
|
||||
)
|
||||
|
||||
# Init LoftQ model
|
||||
lora_model = get_peft_model(model, lora_config)
|
||||
base_model = lora_model.get_base_model()
|
||||
|
||||
# Save LoftQ model
|
||||
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
|
||||
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
|
||||
lora_model.save_pretrained(os.path.join(save_dir, "adapters"))
|
||||
|
||||
# Save base model
|
||||
unwrap_model(base_model)
|
||||
base_model.save_pretrained(save_dir)
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(quantize_loftq)
|
||||
Reference in New Issue
Block a user