Compare commits
146 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cef844079 | ||
|
|
4dcd47100d | ||
|
|
a412b4ed4a | ||
|
|
544a6259b6 | ||
|
|
c501f377dd | ||
|
|
cb8b8f40cd | ||
|
|
70bed8ad8f | ||
|
|
51f776ae2a | ||
|
|
697bc20941 | ||
|
|
1480e3a88f | ||
|
|
19029d5b0f | ||
|
|
7773ac0ead | ||
|
|
23b881bff1 | ||
|
|
10a6c395bb | ||
|
|
f9a7732a1f | ||
|
|
c37582af02 | ||
|
|
ece67f8c7f | ||
|
|
e1838e76fe | ||
|
|
2eede9ffd6 | ||
|
|
a6f6b406b3 | ||
|
|
279439abbe | ||
|
|
13117b69d7 | ||
|
|
5d03ac642d | ||
|
|
5062ee547e | ||
|
|
59817c27e3 | ||
|
|
759bee48d2 | ||
|
|
514ffafc12 | ||
|
|
8b2a735c14 | ||
|
|
10d59e9e4a | ||
|
|
058ed5e607 | ||
|
|
110c2ce2a5 | ||
|
|
c425436676 | ||
|
|
266fe908e3 | ||
|
|
dbd905438b | ||
|
|
d64c87f928 | ||
|
|
29eebef696 | ||
|
|
7bfbcb1fe3 | ||
|
|
9b210cf4b3 | ||
|
|
f74e640565 | ||
|
|
d1d08d066a | ||
|
|
6be321b5da | ||
|
|
3c792174db | ||
|
|
9aeb88c426 | ||
|
|
00e2a272ef | ||
|
|
5142349661 | ||
|
|
0e3cc52327 | ||
|
|
6c1db2d012 | ||
|
|
12c51655ce | ||
|
|
36be12a3b7 | ||
|
|
21fac4c98c | ||
|
|
83404c4fa9 | ||
|
|
12f852b8d4 | ||
|
|
a88873116a | ||
|
|
7cfcd69c64 | ||
|
|
a5eabbe933 | ||
|
|
aa25716a5d | ||
|
|
94c8219575 | ||
|
|
ad24a2a0c9 | ||
|
|
c05027d14a | ||
|
|
5420905a2e | ||
|
|
03f2e3284a | ||
|
|
d2bb1b3a6b | ||
|
|
35c4a2c212 | ||
|
|
1e4010a1fb | ||
|
|
1451297c78 | ||
|
|
0b99b13786 | ||
|
|
f5edbf2b49 | ||
|
|
ab6dc0ea30 | ||
|
|
79d34ce0f3 | ||
|
|
1d2e372a8e | ||
|
|
f6a53d83c8 | ||
|
|
4ec56dd958 | ||
|
|
ba06eb65ca | ||
|
|
be716972fe | ||
|
|
719585a128 | ||
|
|
348f29aa50 | ||
|
|
c8fe3f544b | ||
|
|
0f1ad7140f | ||
|
|
233e167f68 | ||
|
|
1d341dcd83 | ||
|
|
d16561e7a4 | ||
|
|
f8e219dc81 | ||
|
|
3365cc8cf0 | ||
|
|
3a5e68b7d9 | ||
|
|
0cb596fee1 | ||
|
|
b3b5b530d1 | ||
|
|
9225c15c88 | ||
|
|
abd9fed445 | ||
|
|
44cda2eece | ||
|
|
8397808d1d | ||
|
|
9e1bd6420d | ||
|
|
619264c854 | ||
|
|
1ebac62e3d | ||
|
|
ce9bdb3509 | ||
|
|
0c8d6369ac | ||
|
|
bee796f6b5 | ||
|
|
9f6349a333 | ||
|
|
171a029c5e | ||
|
|
eaefaa0fe0 | ||
|
|
d301f0a64b | ||
|
|
0a1578e4e3 | ||
|
|
a4167fd925 | ||
|
|
42084e08ae | ||
|
|
9d23f5dc89 | ||
|
|
5978427ae0 | ||
|
|
c7c216069c | ||
|
|
cde9d1b917 | ||
|
|
96213f04b0 | ||
|
|
7ecea08b9b | ||
|
|
191971865d | ||
|
|
ff4f587dd9 | ||
|
|
de728d0371 | ||
|
|
d08e09642d | ||
|
|
351493b183 | ||
|
|
86ab47e121 | ||
|
|
6dd6b3e396 | ||
|
|
5f1418a68b | ||
|
|
7b97a79efc | ||
|
|
ce4f653121 | ||
|
|
b053c6454e | ||
|
|
ebf0f4a77c | ||
|
|
efa808069a | ||
|
|
b5c5283dd6 | ||
|
|
b638c65519 | ||
|
|
d4d471450f | ||
|
|
3144bdec2c | ||
|
|
c6d6c4c209 | ||
|
|
f5f1589662 | ||
|
|
276f2cb24e | ||
|
|
952b785bb3 | ||
|
|
72dd676208 | ||
|
|
dfaa31e991 | ||
|
|
86556b1c74 | ||
|
|
0c80751e87 | ||
|
|
9338f878a3 | ||
|
|
fde3d91242 | ||
|
|
19adfb88a9 | ||
|
|
daaafa900a | ||
|
|
0dcc9e0bca | ||
|
|
aeec78b35c | ||
|
|
c991654cb4 | ||
|
|
f328413646 | ||
|
|
106a0104da | ||
|
|
5486ea09e3 | ||
|
|
31bbbb6d13 | ||
|
|
1a77de82fa |
130
README.md
130
README.md
@@ -5,7 +5,7 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](#projects-using-llama-factory)
|
||||
[](#projects-using-llama-factory)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@@ -43,10 +43,10 @@ Choose your path:
|
||||
|
||||
## Features
|
||||
|
||||
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
|
||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
|
||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||
@@ -68,14 +68,24 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Changelog
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See `examples/lora_single_gpu/sft_mllm.sh` for usage.
|
||||
|
||||
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
||||
|
||||
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
|
||||
|
||||
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
|
||||
|
||||
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
|
||||
|
||||
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
|
||||
|
||||
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
|
||||
|
||||
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
|
||||
@@ -102,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
|
||||
|
||||
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
|
||||
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
|
||||
|
||||
@@ -126,32 +136,38 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
|
||||
>
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
>
|
||||
> Remember to use the **SAME** template in training and inference.
|
||||
|
||||
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
|
||||
|
||||
@@ -222,6 +238,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||
@@ -241,6 +258,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
|
||||
</details>
|
||||
@@ -275,16 +293,15 @@ huggingface-cli login
|
||||
|
||||
\* *estimated*
|
||||
|
||||
| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
|
||||
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
|
||||
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
|
||||
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
|
||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
|
||||
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
|
||||
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
|
||||
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
|
||||
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
|
||||
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
|
||||
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
|
||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
|
||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -305,7 +322,7 @@ cd LLaMA-Factory
|
||||
pip install -e .[metrics]
|
||||
```
|
||||
|
||||
Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
|
||||
Extra dependencies available: deepspeed, metrics, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
|
||||
|
||||
<details><summary>For Windows users</summary>
|
||||
|
||||
@@ -319,7 +336,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
||||
|
||||
</details>
|
||||
|
||||
### LLaMA Board GUI
|
||||
### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
|
||||
@@ -328,9 +345,20 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
|
||||
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
|
||||
python src/train_web.py # or python -m llmtuner.webui.interface
|
||||
```
|
||||
|
||||
<details><summary>For Alibaba Cloud users</summary>
|
||||
|
||||
If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
|
||||
|
||||
```bash
|
||||
export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Use Docker
|
||||
|
||||
```bash
|
||||
@@ -360,7 +388,7 @@ docker compose -f ./docker-compose.yml up -d
|
||||
|
||||
</details>
|
||||
|
||||
### Command Line Interface
|
||||
### Train with Command Line Interface
|
||||
|
||||
See [examples/README.md](examples/README.md) for usage.
|
||||
|
||||
@@ -370,13 +398,13 @@ Use `python src/train_bash.py -h` to display arguments description.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--template mistral \
|
||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--template llama3 \
|
||||
--infer_backend vllm \
|
||||
--vllm_enforce_eager
|
||||
```
|
||||
|
||||
### Use ModelScope Hub
|
||||
### Download from ModelScope Hub
|
||||
|
||||
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
|
||||
|
||||
@@ -384,7 +412,7 @@ If you have trouble with downloading models and datasets from Hugging Face, you
|
||||
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||
```
|
||||
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `modelscope/Llama-2-7b-ms`.
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||
|
||||
## Projects using LLaMA Factory
|
||||
|
||||
@@ -413,8 +441,14 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
||||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
@@ -427,7 +461,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
142
README_zh.md
142
README_zh.md
@@ -5,13 +5,13 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
|
||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||
|
||||
@@ -23,7 +23,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
选择你的打开方式:
|
||||
|
||||
- **Colab**:https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **本地机器**:请见[如何使用](#如何使用)
|
||||
|
||||
## 目录
|
||||
@@ -43,10 +43,10 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
|
||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||
- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
|
||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
|
||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||
@@ -68,14 +68,24 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
## 更新日志
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 `examples/lora_single_gpu/sft_mllm.sh`。
|
||||
|
||||
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
|
||||
|
||||
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`。
|
||||
|
||||
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`。
|
||||
|
||||
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`。
|
||||
|
||||
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
|
||||
|
||||
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/extras/fsdp_qlora`。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`。
|
||||
|
||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`。
|
||||
@@ -102,7 +112,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
|
||||
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn fa2` 参数以启用 FlashAttention-2。
|
||||
|
||||
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||
|
||||
@@ -126,32 +136,38 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以得到更好的效果。
|
||||
>
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
>
|
||||
> 请务必在训练和推理时使用**完全一致**的模板。
|
||||
|
||||
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
||||
|
||||
@@ -222,6 +238,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||
@@ -241,6 +258,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
|
||||
</details>
|
||||
@@ -275,16 +293,15 @@ huggingface-cli login
|
||||
|
||||
\* *估算值*
|
||||
|
||||
| 训练方法 | 精度 | 7B | 13B | 30B | 70B | 8x7B |
|
||||
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| 全参数 | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
|
||||
| 全参数 | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
|
||||
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||
| 部分参数 | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
|
||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
|
||||
| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
|
||||
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
|
||||
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
|
||||
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
|
||||
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
|
||||
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
|
||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
|
||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
|
||||
|
||||
## 如何使用
|
||||
|
||||
@@ -305,7 +322,7 @@ cd LLaMA-Factory
|
||||
pip install -e .[metrics]
|
||||
```
|
||||
|
||||
可选的额外依赖项:deepspeed、metrics、unsloth、galore、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
|
||||
可选的额外依赖项:deepspeed、metrics、galore、badam、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
|
||||
|
||||
<details><summary>Windows 用户指南</summary>
|
||||
|
||||
@@ -319,18 +336,29 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
|
||||
</details>
|
||||
|
||||
### LLaMA Board 可视化界面
|
||||
### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行分布式训练。
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。
|
||||
|
||||
#### 使用本地环境
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
|
||||
export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
|
||||
python src/train_web.py # 或 python -m llmtuner.webui.interface
|
||||
```
|
||||
|
||||
<details><summary>阿里云用户指南</summary>
|
||||
|
||||
如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
|
||||
|
||||
```bash
|
||||
export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 使用 Docker
|
||||
|
||||
```bash
|
||||
@@ -360,23 +388,23 @@ docker compose -f ./docker-compose.yml up -d
|
||||
|
||||
</details>
|
||||
|
||||
### 命令行接口
|
||||
### 利用命令行接口训练
|
||||
|
||||
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
|
||||
|
||||
使用 `python src/train_bash.py -h` 查看参数文档。
|
||||
您可以执行 `python src/train_bash.py -h` 来查看参数文档。
|
||||
|
||||
### 使用 OpenAI 风格 API 和 vLLM 部署
|
||||
### 利用 vLLM 部署 OpenAI API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--template mistral \
|
||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--template llama3 \
|
||||
--infer_backend vllm \
|
||||
--vllm_enforce_eager
|
||||
```
|
||||
|
||||
### 使用魔搭社区
|
||||
### 从魔搭社区下载
|
||||
|
||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
|
||||
@@ -384,11 +412,11 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
```
|
||||
|
||||
将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `modelscope/Llama-2-7b-ms`。
|
||||
将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
|
||||
如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
|
||||
如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
|
||||
|
||||
<details><summary>点击显示</summary>
|
||||
|
||||
@@ -413,8 +441,14 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
||||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
@@ -427,7 +461,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
"history": "the column name in the dataset containing the histories. (default: None)",
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations)",
|
||||
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
||||
"tools": "the column name in the dataset containing the tool description. (default: None)"
|
||||
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)"
|
||||
},
|
||||
"tags (optional, used for the sharegpt format)": {
|
||||
"role_tag": "the key in the message represents the identity. (default: from)",
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
"history": "数据集代表历史对话的表头名称(默认:None)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||
"tools": "数据集代表工具描述的表头名称(默认:None)"
|
||||
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)"
|
||||
},
|
||||
"tags(可选,用于 sharegpt 格式)": {
|
||||
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||
|
||||
@@ -1 +1 @@
|
||||
34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
|
||||
a97cf9475291591843976554878568e046d8a46d
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
@@ -22,31 +23,19 @@ _URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0
|
||||
|
||||
|
||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
|
||||
})
|
||||
features = datasets.Features(
|
||||
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": file_path
|
||||
}
|
||||
)
|
||||
]
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str):
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
@@ -58,7 +47,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
|
||||
assist_idx = prompt.rfind("Assistant:")
|
||||
human_idx = prompt.rfind("Human:")
|
||||
query = prompt[human_idx+6:assist_idx].strip()
|
||||
query = prompt[human_idx + 6 : assist_idx].strip()
|
||||
prompt = prompt[:human_idx].strip()
|
||||
conversations.insert(0, {"from": "gpt", "value": response})
|
||||
conversations.insert(0, {"from": "human", "value": query})
|
||||
@@ -67,8 +56,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
assist_idx = prompt.rfind("Assistant:")
|
||||
human_idx = prompt.rfind("Human:")
|
||||
if human_idx != -1:
|
||||
old_query = prompt[human_idx+6:assist_idx].strip()
|
||||
old_resp = prompt[assist_idx+10:].strip()
|
||||
old_query = prompt[human_idx + 6 : assist_idx].strip()
|
||||
old_resp = prompt[assist_idx + 10 :].strip()
|
||||
conversations.insert(0, {"from": "gpt", "value": old_resp})
|
||||
conversations.insert(0, {"from": "human", "value": old_query})
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_DESCRIPTION = "An example of dataset."
|
||||
_CITATION = ""
|
||||
@@ -11,34 +12,24 @@ _URL = "examples.json"
|
||||
|
||||
|
||||
class ExampleDataset(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"input": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
|
||||
})
|
||||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"input": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": file_path
|
||||
}
|
||||
)
|
||||
]
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
|
||||
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import datasets
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||
_CITATION = ""
|
||||
@@ -14,50 +16,37 @@ _URLS = {
|
||||
_URL + "harmless-base/train.jsonl.gz",
|
||||
_URL + "helpful-base/train.jsonl.gz",
|
||||
_URL + "helpful-online/train.jsonl.gz",
|
||||
_URL + "helpful-rejection-sampled/train.jsonl.gz"
|
||||
_URL + "helpful-rejection-sampled/train.jsonl.gz",
|
||||
],
|
||||
"test": [
|
||||
_URL + "harmless-base/test.jsonl.gz",
|
||||
_URL + "helpful-base/test.jsonl.gz",
|
||||
_URL + "helpful-online/test.jsonl.gz",
|
||||
_URL + "helpful-rejection-sampled/test.jsonl.gz"
|
||||
]
|
||||
_URL + "helpful-rejection-sampled/test.jsonl.gz",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Sequence(datasets.Value("string")),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
|
||||
})
|
||||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Sequence(datasets.Value("string")),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_path = dl_manager.download_and_extract(_URLS)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepaths": file_path["train"]
|
||||
}
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepaths": file_path["test"]
|
||||
}
|
||||
)
|
||||
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
|
||||
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
@@ -70,12 +59,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
rejected = data["rejected"]
|
||||
|
||||
assist_idx = rejected.rfind("\n\nAssistant: ")
|
||||
r_reject = rejected[assist_idx+13:].strip()
|
||||
r_reject = rejected[assist_idx + 13 :].strip()
|
||||
assist_idx = chosen.rfind("\n\nAssistant: ")
|
||||
r_accept = chosen[assist_idx+13:].strip()
|
||||
r_accept = chosen[assist_idx + 13 :].strip()
|
||||
|
||||
human_idx = chosen.rfind("\n\nHuman: ")
|
||||
query = chosen[human_idx+9:assist_idx].strip()
|
||||
query = chosen[human_idx + 9 : assist_idx].strip()
|
||||
prompt = chosen[:human_idx]
|
||||
history = []
|
||||
|
||||
@@ -83,16 +72,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
assist_idx = prompt.rfind("\n\nAssistant: ")
|
||||
human_idx = prompt.rfind("\n\nHuman: ")
|
||||
if human_idx != -1:
|
||||
old_query = prompt[human_idx+9:assist_idx].strip()
|
||||
old_resp = prompt[assist_idx+13:].strip()
|
||||
old_query = prompt[human_idx + 9 : assist_idx].strip()
|
||||
old_resp = prompt[assist_idx + 13 :].strip()
|
||||
history.insert(0, (old_query, old_resp))
|
||||
else:
|
||||
break
|
||||
prompt = prompt[:human_idx]
|
||||
|
||||
yield key, {
|
||||
"instruction": query,
|
||||
"output": [r_accept, r_reject],
|
||||
"history": history
|
||||
}
|
||||
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
|
||||
key += 1
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import datasets
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||
|
||||
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
|
||||
@@ -24,31 +26,19 @@ _BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jso
|
||||
|
||||
|
||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
|
||||
})
|
||||
features = datasets.Features(
|
||||
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepaths": file_paths
|
||||
}
|
||||
)
|
||||
]
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
for filepath in filepaths:
|
||||
@@ -56,7 +46,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
for row in f:
|
||||
try:
|
||||
data = json.loads(row)
|
||||
except:
|
||||
except Exception:
|
||||
continue
|
||||
key: int = data["id"]
|
||||
content: List[str] = data["data"]
|
||||
@@ -64,8 +54,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
content.pop(-1)
|
||||
if len(content) < 2:
|
||||
continue
|
||||
conversations = [{
|
||||
"from": "human" if i % 2 == 0 else "gpt",
|
||||
"value": content[i]
|
||||
} for i in range(len(content))]
|
||||
conversations = [
|
||||
{"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
|
||||
]
|
||||
yield key, {"conversations": conversations}
|
||||
|
||||
@@ -3,41 +3,48 @@ We provide diverse examples about fine-tuning LLMs.
|
||||
```
|
||||
examples/
|
||||
├── lora_single_gpu/
|
||||
│ ├── pretrain.sh: Do pre-training
|
||||
│ ├── sft.sh: Do supervised fine-tuning
|
||||
│ ├── reward.sh: Do reward modeling
|
||||
│ ├── ppo.sh: Do PPO training
|
||||
│ ├── dpo.sh: Do DPO training
|
||||
│ ├── orpo.sh: Do ORPO training
|
||||
│ ├── pretrain.sh: Do continuous pre-training using LoRA
|
||||
│ ├── sft.sh: Do supervised fine-tuning using LoRA
|
||||
│ ├── reward.sh: Do reward modeling using LoRA
|
||||
│ ├── ppo.sh: Do PPO training using LoRA
|
||||
│ ├── dpo.sh: Do DPO training using LoRA
|
||||
│ ├── orpo.sh: Do ORPO training using LoRA
|
||||
│ ├── sft_mllm.sh: Do supervised fine-tuning on multimodal data using LoRA
|
||||
│ ├── prepare.sh: Save tokenized dataset
|
||||
│ └── predict.sh: Do batch predict
|
||||
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
|
||||
├── qlora_single_gpu/
|
||||
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models
|
||||
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models
|
||||
│ ├── awq.sh: Fine-tune 4-bit AWQ models
|
||||
│ └── aqlm.sh: Fine-tune 2-bit AQLM models
|
||||
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
|
||||
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
|
||||
│ ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
|
||||
│ └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
|
||||
├── lora_multi_gpu/
|
||||
│ ├── single_node.sh: Fine-tune model with Accelerate on single node
|
||||
│ └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes
|
||||
│ ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
|
||||
│ ├── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
|
||||
│ └── ds_zero3.sh: Fine-tune model with DeepSpeed ZeRO-3 using LoRA (weight sharding)
|
||||
├── full_multi_gpu/
|
||||
│ ├── single_node.sh: Fine-tune model with DeepSpeed on single node
|
||||
│ └── multi_node.sh: Fine-tune model with DeepSpeed on multiple nodes
|
||||
│ ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
|
||||
│ ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
|
||||
│ └── predict.sh: Do parallel batch predict and compute BLEU and ROUGE scores after full tuning
|
||||
├── merge_lora/
|
||||
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
|
||||
│ └── quantize.sh: Quantize fine-tuned model with AutoGPTQ
|
||||
│ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
|
||||
├── inference/
|
||||
│ ├── cli_demo.sh: Launch a command line interface
|
||||
│ ├── api_demo.sh: Launch an OpenAI-style API
|
||||
│ ├── web_demo.sh: Launch a web interface
|
||||
│ └── evaluate.sh: Evaluate model on the MMLU benchmark
|
||||
│ ├── cli_demo.sh: Chat with fine-tuned model in the CLI with LoRA adapters
|
||||
│ ├── api_demo.sh: Chat with fine-tuned model in an OpenAI-style API with LoRA adapters
|
||||
│ ├── web_demo.sh: Chat with fine-tuned model in the Web browser with LoRA adapters
|
||||
│ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
|
||||
└── extras/
|
||||
├── galore/
|
||||
│ └── sft.sh: Fine-tune model with GaLore
|
||||
├── badam/
|
||||
│ └── sft.sh: Fine-tune model with BAdam
|
||||
├── loraplus/
|
||||
│ └── sft.sh: Fine-tune model with LoRA+
|
||||
│ └── sft.sh: Fine-tune model using LoRA+
|
||||
├── mod/
|
||||
│ └── sft.sh: Fine-tune model using Mixture-of-Depths
|
||||
├── llama_pro/
|
||||
│ ├── expand.sh: Expand layers in the model
|
||||
│ └── sft.sh: Fine-tune expanded model
|
||||
│ └── sft.sh: Fine-tune the expanded model
|
||||
└── fsdp_qlora/
|
||||
└── sft.sh: Fine-tune quantized model with FSDP
|
||||
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
|
||||
```
|
||||
|
||||
@@ -1,43 +1,50 @@
|
||||
我们提供了多样化的示例脚本。
|
||||
我们提供了多样化的大模型微调示例脚本。
|
||||
|
||||
```
|
||||
examples/
|
||||
├── lora_single_gpu/
|
||||
│ ├── pretrain.sh: 进行预训练
|
||||
│ ├── sft.sh: 进行指令监督微调
|
||||
│ ├── reward.sh: 进行奖励模型训练
|
||||
│ ├── ppo.sh: 进行 PPO 训练
|
||||
│ ├── dpo.sh: 进行 DPO 训练
|
||||
│ ├── orpo.sh: 进行 ORPO 训练
|
||||
│ ├── pretrain.sh: 基于 LoRA 进行增量预训练
|
||||
│ ├── sft.sh: 基于 LoRA 进行指令监督微调
|
||||
│ ├── reward.sh: 基于 LoRA 进行奖励模型训练
|
||||
│ ├── ppo.sh: 基于 LoRA 进行 PPO 训练
|
||||
│ ├── dpo.sh: 基于 LoRA 进行 DPO 训练
|
||||
│ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
|
||||
│ ├── sft_mllm.sh: 基于 LoRA 进行多模态指令监督微调
|
||||
│ ├── prepare.sh: 保存预处理后的数据集
|
||||
│ └── predict.sh: 进行批量预测
|
||||
│ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
|
||||
├── qlora_single_gpu/
|
||||
│ ├── bitsandbytes.sh: 微调 4/8 比特 BNB 模型
|
||||
│ ├── gptq.sh: 微调 4/8 比特 GPTQ 模型
|
||||
│ ├── awq.sh: 微调 4 比特 AWQ 模型
|
||||
│ └── aqlm.sh: 微调 2 比特 AQLM 模型
|
||||
│ ├── bitsandbytes.sh: 基于 QLoRA 微调 4/8 比特 BNB 模型
|
||||
│ ├── gptq.sh: 基于 QLoRA 微调 4/8 比特 GPTQ 模型
|
||||
│ ├── awq.sh: 基于 QLoRA 微调 4 比特 AWQ 模型
|
||||
│ └── aqlm.sh: 基于 QLoRA 微调 2 比特 AQLM 模型
|
||||
├── lora_multi_gpu/
|
||||
│ ├── single_node.sh: 使用 Accelerate 进行单节点训练
|
||||
│ └── multi_node.sh: 使用 Accelerate 进行多节点训练
|
||||
│ ├── single_node.sh: 使用 Accelerate 进行单节点 LoRA 训练
|
||||
│ ├── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
|
||||
│ └── ds_zero3.sh: 使用 DeepSpeed ZeRO-3 进行 LoRA 训练(拆分权重)
|
||||
├── full_multi_gpu/
|
||||
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点训练
|
||||
│ └── multi_node.sh: 使用 DeepSpeed 进行多节点训练
|
||||
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点全量训练
|
||||
│ ├── multi_node.sh: 使用 DeepSpeed 进行多节点全量训练
|
||||
│ └── predict.sh: 基于全量训练进行多卡批量预测并计算 BLEU 和 ROUGE 分数
|
||||
├── merge_lora/
|
||||
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
|
||||
│ └── quantize.sh: 使用 AutoGPTQ 量化模型
|
||||
│ └── quantize.sh: 使用 AutoGPTQ 量化微调后的模型
|
||||
├── inference/
|
||||
│ ├── cli_demo.sh: 启动命令行推理接口
|
||||
│ ├── api_demo.sh: 启动 OpenAI 风格 API
|
||||
│ ├── web_demo.sh: 启动浏览器推理接口
|
||||
│ └── evaluate.sh: 在 MMLU 数据集上评测模型
|
||||
│ ├── cli_demo.sh: 启动 LoRA 模型的命令行推理接口
|
||||
│ ├── api_demo.sh: 启动 LoRA 模型的 OpenAI 风格 API
|
||||
│ ├── web_demo.sh: 启动 LoRA 模型的浏览器推理接口
|
||||
│ └── evaluate.sh: 在 MMLU/CMMLU/C-Eval 数据集上评测 LoRA 模型
|
||||
└── extras/
|
||||
├── galore/
|
||||
│ └── sft.sh: 使用 GaLore 训练模型
|
||||
├── badam/
|
||||
│ └── sft.sh: 使用 BAdam 训练模型
|
||||
├── loraplus/
|
||||
│ └── sft.sh: 使用 LoRA+ 训练模型
|
||||
├── mod/
|
||||
│ └── sft.sh: 使用深度混合训练模型
|
||||
├── llama_pro/
|
||||
│ ├── expand.sh: 扩展模型中的层
|
||||
│ └── sft.sh: 训练扩展后的模型
|
||||
└── fsdp_qlora/
|
||||
└── sft.sh: 使用 FSDP 微调量化模型
|
||||
└── sft.sh: 使用 FSDP+QLoRA 微调量化模型
|
||||
```
|
||||
|
||||
@@ -9,7 +9,7 @@ main_process_port: 29555
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 2 # the number of nodes
|
||||
num_processes: 16 # the number of GPUs in all nodes
|
||||
num_processes: 8 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
||||
@@ -9,7 +9,7 @@ main_process_port: 29555
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 2 # the number of nodes
|
||||
num_processes: 16 # the number of GPUs in all nodes
|
||||
num_processes: 8 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
||||
35
examples/extras/badam/sft.sh
Normal file
35
examples/extras/badam/sft.sh
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../../data \
|
||||
--template default \
|
||||
--finetuning_type full \
|
||||
--use_badam \
|
||||
--badam_switch_mode descending \
|
||||
--badam_switch_block_every 50 \
|
||||
--badam_verbose 2 \
|
||||
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--load_best_model_at_end \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--pure_bf16
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/bin/bash
|
||||
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA
|
||||
|
||||
pip install "transformers>=4.39.1"
|
||||
pip install "accelerate>=0.28.0"
|
||||
|
||||
@@ -12,6 +12,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
--galore_layerwise \
|
||||
--galore_target mlp,self_attn \
|
||||
--galore_rank 128 \
|
||||
--galore_scale 2.0 \
|
||||
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
|
||||
@@ -9,6 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--loraplus_lr_ratio 16.0 \
|
||||
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
@@ -29,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--fp16 \
|
||||
--loraplus_lr_ratio 16.0
|
||||
--fp16
|
||||
|
||||
33
examples/extras/mod/sft.sh
Normal file
33
examples/extras/mod/sft.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../../data \
|
||||
--template default \
|
||||
--finetuning_type full \
|
||||
--mixture_of_depths convert \
|
||||
--output_dir ../../../saves/LLaMA2-7B/mod/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--optim paged_adamw_8bit \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--load_best_model_at_end \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--pure_bf16
|
||||
20
examples/full_multi_gpu/predict.sh
Normal file
20
examples/full_multi_gpu/predict.sh
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/single_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
--stage sft \
|
||||
--do_predict \
|
||||
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
|
||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../data \
|
||||
--template default \
|
||||
--finetuning_type full \
|
||||
--output_dir ../../saves/LLaMA2-7B/full/predict \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--max_samples 20 \
|
||||
--predict_with_generate
|
||||
@@ -3,7 +3,7 @@
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template vanilla \
|
||||
--template fewshot \
|
||||
--finetuning_type lora \
|
||||
--task mmlu \
|
||||
--split test \
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/bin/bash
|
||||
# add `--visual_inputs True` to load MLLM
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
||||
33
examples/lora_multi_gpu/ds_zero3.sh
Normal file
33
examples/lora_multi_gpu/ds_zero3.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../data \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--ddp_timeout 180000000 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/bin/bash
|
||||
# also launch it on slave machine using slave_config.yaml
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/master_config.yaml \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/single_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
--stage sft \
|
||||
|
||||
33
examples/lora_single_gpu/sft_mllm.sh
Normal file
33
examples/lora_single_gpu/sft_mllm.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path llava-hf/llava-1.5-7b-hf \
|
||||
--visual_inputs \
|
||||
--dataset mllm_demo \
|
||||
--dataset_dir ../../data \
|
||||
--template vicuna \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir ../../saves/LLaMA2-7B/lora/sft_mllm \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--load_best_model_at_end \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 100.0 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
@@ -1,11 +1,12 @@
|
||||
#!/bin/bash
|
||||
# DO NOT use quantized model or quantization_bit when merging lora weights
|
||||
|
||||
CUDA_VISIBLE_DEVICES= python ../../src/export_model.py \
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--export_dir ../../models/llama2-7b-sft \
|
||||
--export_size 2 \
|
||||
--export_device cpu \
|
||||
--export_legacy_format False
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/bin/bash
|
||||
# NEED TO run `merge.sh` before using this script
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
--model_name_or_path ../../models/llama2-7b-sft \
|
||||
|
||||
@@ -4,7 +4,7 @@ datasets>=2.14.3
|
||||
accelerate>=0.27.2
|
||||
peft>=0.10.0
|
||||
trl>=0.8.1
|
||||
gradio>=4.0.0,<=4.21.0
|
||||
gradio>=4.0.0
|
||||
scipy
|
||||
einops
|
||||
sentencepiece
|
||||
@@ -15,3 +15,4 @@ fastapi
|
||||
sse-starlette
|
||||
matplotlib
|
||||
fire
|
||||
packaging
|
||||
|
||||
@@ -44,8 +44,9 @@ def calculate_lr(
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
|
||||
@@ -32,8 +32,8 @@ def length_cdf(
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
total_num = len(trainset)
|
||||
length_dict = defaultdict(int)
|
||||
for sample in tqdm(trainset["input_ids"]):
|
||||
|
||||
4
setup.py
4
setup.py
@@ -22,9 +22,9 @@ def get_requires():
|
||||
extra_require = {
|
||||
"deepspeed": ["deepspeed>=0.10.0"],
|
||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
|
||||
"galore": ["galore-torch"],
|
||||
"vllm": ["vllm>=0.3.3"],
|
||||
"badam": ["badam"],
|
||||
"vllm": ["vllm>=0.4.0"],
|
||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
|
||||
@@ -7,5 +7,5 @@ from .train import export_model, run_exp
|
||||
from .webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.6.2"
|
||||
__version__ = "0.7.0"
|
||||
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
|
||||
|
||||
@@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from vllm import AsyncLLMEngine
|
||||
|
||||
from ..data import Template
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import AsyncLLMEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
@@ -49,6 +47,7 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]: ...
|
||||
|
||||
@@ -58,6 +57,7 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
@@ -36,9 +38,10 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
async def achat(
|
||||
@@ -46,18 +49,20 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
return await self.engine.chat(messages, system, tools, **input_kwargs)
|
||||
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
generator = self.astream_chat(messages, system, tools, **input_kwargs)
|
||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||
@@ -70,9 +75,10 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
def get_scores(
|
||||
|
||||
@@ -14,7 +14,9 @@ from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data import Template
|
||||
@@ -30,7 +32,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
self.tokenizer = load_tokenizer(model_args)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
self.tokenizer = tokenizer_module["tokenizer"]
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.model = load_model(
|
||||
@@ -42,13 +46,18 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _process_args(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
|
||||
messages[0]["content"] = "<image>" + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
prompt_ids, _ = template.encode_oneturn(
|
||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
@@ -95,6 +104,11 @@ class HuggingfaceEngine(BaseEngine):
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if processor is not None and image is not None:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@staticmethod
|
||||
@@ -102,15 +116,17 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _chat(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@@ -135,15 +151,17 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _stream_chat(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@@ -199,6 +217,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@@ -208,11 +227,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
input_args = (
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
self.processor,
|
||||
self.template,
|
||||
self.generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
@@ -224,6 +245,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
@@ -233,11 +255,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
input_args = (
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
self.processor,
|
||||
self.template,
|
||||
self.generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.misc import get_device_count, infer_optim_dtype
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..model import load_tokenizer
|
||||
from ..model import load_config, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import MultiModalData
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
@@ -25,32 +30,59 @@ class VllmEngine(BaseEngine):
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
infer_dtype = str(infer_dtype).split(".")[-1]
|
||||
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=model_args.vllm_maxlen,
|
||||
tensor_parallel_size=get_device_count() or 1,
|
||||
gpu_memory_utilization=model_args.vllm_gpu_util,
|
||||
disable_log_stats=True,
|
||||
disable_log_requests=True,
|
||||
enforce_eager=model_args.vllm_enforce_eager,
|
||||
)
|
||||
self.model = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
self.tokenizer = load_tokenizer(model_args)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
self.tokenizer = tokenizer_module["tokenizer"]
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.generating_args = generating_args.to_dict()
|
||||
|
||||
engine_args = {
|
||||
"model": model_args.model_name_or_path,
|
||||
"trust_remote_code": True,
|
||||
"download_dir": model_args.cache_dir,
|
||||
"dtype": infer_dtype,
|
||||
"max_model_len": model_args.vllm_maxlen,
|
||||
"tensor_parallel_size": get_device_count() or 1,
|
||||
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
||||
"disable_log_stats": True,
|
||||
"disable_log_requests": True,
|
||||
"enforce_eager": model_args.vllm_enforce_eager,
|
||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||
}
|
||||
|
||||
if model_args.visual_inputs:
|
||||
# TODO: auto derive from config
|
||||
# https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
|
||||
self.image_feature_size = 576
|
||||
engine_args["image_input_type"] = "pixel_values"
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
|
||||
engine_args["image_input_shape"] = "1,3,336,336"
|
||||
engine_args["image_feature_size"] = self.image_feature_size
|
||||
|
||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
|
||||
else:
|
||||
self.lora_request = None
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
|
||||
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
prompt_ids, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
@@ -94,8 +126,21 @@ class VllmEngine(BaseEngine):
|
||||
max_tokens=generating_args["max_new_tokens"],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
if self.processor is not None and image is not None:
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
result_generator = self.model.generate(
|
||||
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
|
||||
prompt=None,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_ids,
|
||||
lora_request=self.lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
return result_generator
|
||||
|
||||
@@ -107,10 +152,11 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
final_output = None
|
||||
generator = await self._generate(messages, system, tools, **input_kwargs)
|
||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
||||
async for request_output in generator:
|
||||
final_output = request_output
|
||||
|
||||
@@ -132,10 +178,11 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generated_text = ""
|
||||
generator = await self._generate(messages, system, tools, **input_kwargs)
|
||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
||||
async for result in generator:
|
||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||
generated_text = result.outputs[0].text
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
@@ -13,8 +14,23 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
|
||||
outputs.append(os.path.join(data_args.dataset_dir, image))
|
||||
else:
|
||||
outputs.append(image)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||
@@ -44,12 +60,16 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
@@ -84,6 +104,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -96,12 +117,13 @@ def align_dataset(
|
||||
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||
system: "..."
|
||||
tools: "..."
|
||||
tools: "...",
|
||||
images: [],
|
||||
"""
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
||||
else:
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
features = Features.from_dict(
|
||||
@@ -114,6 +136,7 @@ def align_dataset(
|
||||
],
|
||||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
@@ -16,7 +16,7 @@ from .utils import checksum, merge_dataset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
@@ -115,11 +115,12 @@ def load_single_dataset(
|
||||
|
||||
|
||||
def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
@@ -149,7 +150,7 @@ def get_dataset(
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
tokenizer, template, data_args, training_args, stage
|
||||
data_args, training_args, stage, template, tokenizer, processor
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
|
||||
@@ -28,6 +28,7 @@ class DatasetAttr:
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
system: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
""" columns for the alpaca format """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
@@ -105,7 +106,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
column_names = ["system", "images"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_pillow_available
|
||||
from .utils import Role
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
@@ -18,6 +26,13 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
# process visual inputs (currently only supports a single image)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0]
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
@@ -48,18 +63,25 @@ def preprocess_pretrain_dataset(
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
@@ -89,14 +111,16 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
@@ -141,17 +165,24 @@ def preprocess_packed_supervised_dataset(
|
||||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
@@ -172,22 +203,32 @@ def preprocess_unsupervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
@@ -214,6 +255,8 @@ def preprocess_pairwise_dataset(
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
@@ -244,34 +287,54 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||
preprocess_func = partial(
|
||||
preprocess_pretrain_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_packed_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "rm":
|
||||
preprocess_func = partial(
|
||||
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_pairwise_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_unsupervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
|
||||
@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
|
||||
name: Optional[str] = None,
|
||||
) -> Template:
|
||||
if name is None:
|
||||
template = templates["vanilla"] # placeholder
|
||||
template = templates["empty"] # placeholder
|
||||
else:
|
||||
template = templates.get(name, None)
|
||||
if template is None:
|
||||
@@ -385,7 +385,8 @@ _register_template(
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -502,6 +503,7 @@ _register_template(
|
||||
name="chatml",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
@@ -512,6 +514,7 @@ _register_template(
|
||||
name="chatml_de",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
@@ -526,6 +529,21 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="cohere",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
|
||||
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="cpm",
|
||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||
@@ -534,6 +552,32 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="dbrx",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
||||
"You answer questions based on information available up to that point.\n"
|
||||
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
|
||||
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
|
||||
"from writing to coding (using markdown for code blocks — remember to use ``` with "
|
||||
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
|
||||
"capabilities. You avoid stereotyping and provide balanced perspectives on "
|
||||
"controversial topics. You do not provide song lyrics, poems, or news articles and "
|
||||
"do not divulge details of your training data.)\nThis is your system prompt, "
|
||||
"guiding your responses. Do not reference it, just respond to the user. If you find "
|
||||
"yourself talking about this message, stop. You should be responding appropriately "
|
||||
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
|
||||
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
@@ -566,6 +610,13 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="empty",
|
||||
format_user=StringFormatter(slots=["{{content}}"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
@@ -574,10 +625,20 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="fewshot",
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="gemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
@@ -635,9 +696,36 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llama3",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
@@ -669,10 +757,23 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="phi",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful AI assistant.",
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
@@ -699,13 +800,6 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="vanilla",
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="vicuna",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
@@ -776,7 +870,7 @@ _register_template(
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -78,9 +78,9 @@ def split_dataset(
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||
else:
|
||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||
|
||||
@@ -21,7 +21,7 @@ from .template import get_eval_template
|
||||
class Evaluator:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.tokenizer = load_tokenizer(self.model_args)
|
||||
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
||||
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
||||
|
||||
@@ -28,6 +28,10 @@ LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
MLLM_LIST = ["LLaVA1.5"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
|
||||
|
||||
PEFT_METHODS = ["lora"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
@@ -45,6 +49,8 @@ TRAINING_STAGES = {
|
||||
|
||||
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
|
||||
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
|
||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
||||
@@ -242,6 +248,44 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CommandR-35B-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
|
||||
},
|
||||
"CommandR-Plus-104B-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
|
||||
},
|
||||
"CommandR-35B-4bit-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
|
||||
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
|
||||
},
|
||||
"CommandR-Plus-104B-4bit-Chat": {
|
||||
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
|
||||
},
|
||||
},
|
||||
template="cohere",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DBRX-132B-Base": {
|
||||
DownloadSource.DEFAULT: "databricks/dbrx-base",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
|
||||
},
|
||||
"DBRX-132B-Chat": {
|
||||
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
|
||||
},
|
||||
},
|
||||
module="Wqkv",
|
||||
template="dbrx",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeek-LLM-7B-Base": {
|
||||
@@ -262,9 +306,11 @@ register_model_group(
|
||||
},
|
||||
"DeepSeek-Math-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
|
||||
},
|
||||
"DeepSeek-Math-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
|
||||
},
|
||||
"DeepSeek-MoE-16B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||
@@ -363,6 +409,23 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGemma-2B": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-2b",
|
||||
},
|
||||
"CodeGemma-7B": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-7b",
|
||||
},
|
||||
"CodeGemma-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-7b-it",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
|
||||
},
|
||||
},
|
||||
template="gemma",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
@@ -410,6 +473,16 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Jambda-v0.1": {
|
||||
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": {
|
||||
@@ -474,6 +547,42 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA3-8B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
||||
},
|
||||
"LLaMA3-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
||||
},
|
||||
"LLaMA3-8B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
||||
},
|
||||
"LLaMA3-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
||||
},
|
||||
},
|
||||
template="llama3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
||||
},
|
||||
"LLaVA1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||
},
|
||||
},
|
||||
template="vicuna",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B-v0.1": {
|
||||
@@ -499,14 +608,21 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mixtral-8x7B": {
|
||||
"Mixtral-8x7B-v0.1": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
|
||||
},
|
||||
"Mixtral-8x7B-Chat": {
|
||||
"Mixtral-8x7B-v0.1-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
|
||||
},
|
||||
"Mixtral-8x22B-v0.1": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
|
||||
},
|
||||
"Mixtral-8x22B-v0.1-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
|
||||
},
|
||||
},
|
||||
template="mistral",
|
||||
)
|
||||
@@ -515,18 +631,15 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"OLMo-1B": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-1B",
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
|
||||
},
|
||||
"OLMo-7B": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-7B",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
|
||||
},
|
||||
"OLMo-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
|
||||
"OLMo-1.7-7B": {
|
||||
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
|
||||
},
|
||||
},
|
||||
module="att_proj",
|
||||
template="olmo",
|
||||
)
|
||||
|
||||
|
||||
@@ -534,7 +647,7 @@ register_model_group(
|
||||
models={
|
||||
"OpenChat3.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
|
||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
|
||||
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
|
||||
}
|
||||
},
|
||||
template="openchat",
|
||||
@@ -582,6 +695,22 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi3-3.8B-4k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
||||
DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-4k-instruct",
|
||||
},
|
||||
"Phi3-3.8B-128k-Chat": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
||||
DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-128k-instruct",
|
||||
},
|
||||
},
|
||||
module="qkv_proj",
|
||||
template="phi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {
|
||||
@@ -684,10 +813,18 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
|
||||
},
|
||||
"Qwen1.5-110B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
|
||||
},
|
||||
"Qwen1.5-Code-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
|
||||
@@ -716,10 +853,18 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
|
||||
},
|
||||
"Qwen1.5-110B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
},
|
||||
"Qwen1.5-Code-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
||||
},
|
||||
"Qwen1.5-0.5B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
@@ -772,10 +917,18 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-110B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-Code-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
@@ -809,12 +962,15 @@ register_model_group(
|
||||
models={
|
||||
"StarCoder2-3B": {
|
||||
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
|
||||
},
|
||||
"StarCoder2-7B": {
|
||||
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
|
||||
},
|
||||
"StarCoder2-15B": {
|
||||
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -837,17 +993,53 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XuanYuan-6B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
|
||||
},
|
||||
"XuanYuan-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
||||
},
|
||||
"XuanYuan-2-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
||||
},
|
||||
"XuanYuan-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
||||
},
|
||||
"XuanYuan-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
},
|
||||
"XuanYuan-2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||
},
|
||||
"XuanYuan-6B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-6B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||
},
|
||||
"XuanYuan-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
},
|
||||
"XuanYuan-2-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-2-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||
},
|
||||
},
|
||||
template="xuanyuan",
|
||||
@@ -884,6 +1076,30 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
||||
},
|
||||
"XVERSE-MoE-A4.2B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
|
||||
},
|
||||
"XVERSE-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-13B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-13B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-65B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
},
|
||||
},
|
||||
template="xverse",
|
||||
)
|
||||
@@ -976,21 +1192,9 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
||||
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
|
||||
},
|
||||
"Zephyr-141B-ORPO-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
|
||||
},
|
||||
},
|
||||
template="zephyr",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Atom-7B": {
|
||||
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
|
||||
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
|
||||
},
|
||||
"Atom-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
|
||||
},
|
||||
},
|
||||
template="atom",
|
||||
)
|
||||
|
||||
@@ -66,7 +66,6 @@ def check_dependencies() -> None:
|
||||
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
|
||||
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
|
||||
require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
@@ -84,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
||||
num_bytes = param.quant_storage.itemsize
|
||||
elif hasattr(param, "element_size"): # for older pytorch version
|
||||
num_bytes = param.element_size()
|
||||
else:
|
||||
num_bytes = 1
|
||||
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from packaging import version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
def _is_package_available(name: str) -> bool:
|
||||
return importlib.util.find_spec(name) is not None
|
||||
|
||||
|
||||
def _get_package_version(name: str) -> str:
|
||||
def _get_package_version(name: str) -> "Version":
|
||||
try:
|
||||
return importlib.metadata.version(name)
|
||||
return version.parse(importlib.metadata.version(name))
|
||||
except Exception:
|
||||
return "0.0.0"
|
||||
return version.parse("0.0.0")
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
@@ -18,13 +25,17 @@ def is_fastapi_availble():
|
||||
|
||||
|
||||
def is_flash_attn2_available():
|
||||
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
|
||||
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
|
||||
|
||||
|
||||
def is_galore_available():
|
||||
return _is_package_available("galore_torch")
|
||||
|
||||
|
||||
def is_gradio_available():
|
||||
return _is_package_available("gradio")
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return _is_package_available("jieba")
|
||||
|
||||
@@ -37,6 +48,10 @@ def is_nltk_available():
|
||||
return _is_package_available("nltk")
|
||||
|
||||
|
||||
def is_pillow_available():
|
||||
return _is_package_available("PIL")
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return _is_package_available("requests")
|
||||
|
||||
@@ -45,14 +60,14 @@ def is_rouge_available():
|
||||
return _is_package_available("rouge_chinese")
|
||||
|
||||
|
||||
def is_sdpa_available():
|
||||
return _get_package_version("torch") > version.parse("2.1.1")
|
||||
|
||||
|
||||
def is_starlette_available():
|
||||
return _is_package_available("sse_starlette")
|
||||
|
||||
|
||||
def is_unsloth_available():
|
||||
return _is_package_available("unsloth")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _is_package_available("uvicorn")
|
||||
|
||||
|
||||
@@ -26,11 +26,11 @@ class DataArguments:
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "The cutoff length of the model inputs after tokenization."},
|
||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||
)
|
||||
reserved_label_len: int = field(
|
||||
default=1,
|
||||
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
|
||||
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
|
||||
)
|
||||
train_on_prompt: bool = field(
|
||||
default=False,
|
||||
|
||||
@@ -172,7 +172,7 @@ class GaloreArguments:
|
||||
|
||||
use_galore: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use gradient low-Rank projection."},
|
||||
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
|
||||
)
|
||||
galore_target: str = field(
|
||||
default="all",
|
||||
@@ -204,7 +204,54 @@ class GaloreArguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
Arguments pertaining to the BAdam optimizer.
|
||||
"""
|
||||
|
||||
use_badam: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the BAdam optimizer."},
|
||||
)
|
||||
badam_mode: Literal["layer", "ratio"] = field(
|
||||
default="layer",
|
||||
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
||||
)
|
||||
badam_start_block: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_block_every: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_update_ratio: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
|
||||
)
|
||||
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
||||
default="adjacent",
|
||||
metadata={
|
||||
"help": """The mode of the mask for BAdam optimizer. \
|
||||
`adjacent` means that the trainable parameters are adjacent to each other, \
|
||||
`scatter` means that trainable parameters are randomly choosed from the weight."""
|
||||
},
|
||||
)
|
||||
badam_verbose: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": """The verbosity level of BAdam optimizer. \
|
||||
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
@@ -256,11 +303,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
||||
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
|
||||
|
||||
if self.use_galore and self.finetuning_type == "lora":
|
||||
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
@@ -31,11 +31,11 @@ class GeneratingArguments:
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
|
||||
)
|
||||
max_length: int = field(
|
||||
default=512,
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
|
||||
)
|
||||
max_new_tokens: int = field(
|
||||
default=512,
|
||||
default=1024,
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
||||
)
|
||||
repetition_penalty: float = field(
|
||||
|
||||
@@ -22,7 +22,7 @@ class ModelArguments:
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=False,
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||
)
|
||||
resize_vocab: bool = field(
|
||||
@@ -33,6 +33,10 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
new_special_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
@@ -55,24 +59,32 @@ class ModelArguments:
|
||||
)
|
||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Device map used for loading the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
flash_attn: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."},
|
||||
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
|
||||
default="auto",
|
||||
metadata={"help": "Enable FlashAttention for faster training and inference."},
|
||||
)
|
||||
shift_attn: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
)
|
||||
use_unsloth: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
visual_inputs: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
@@ -129,6 +141,10 @@ class ModelArguments:
|
||||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||
)
|
||||
export_device: str = field(
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
@@ -166,9 +182,15 @@ class ModelArguments:
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.visual_inputs and self.use_unsloth:
|
||||
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
|
||||
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
if self.new_special_tokens is not None: # support multiple special tokens
|
||||
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
|
||||
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies
|
||||
from ..extras.packages import is_unsloth_available
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
@@ -67,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if model_args.resize_vocab:
|
||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
|
||||
raise ValueError("Cannot create new adapter upon a quantized model.")
|
||||
|
||||
@@ -74,6 +77,35 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["Seq2SeqTrainingArguments"] = None,
|
||||
) -> None:
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
|
||||
if model_args.mixture_of_depths is not None:
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
require_version("badam", "To fix: pip install badam")
|
||||
|
||||
if finetuning_args.plot_loss:
|
||||
require_version("matplotlib", "To fix: pip install matplotlib")
|
||||
|
||||
if training_args is not None and training_args.predict_with_generate:
|
||||
require_version("jieba", "To fix: pip install jieba")
|
||||
require_version("nltk", "To fix: pip install nltk")
|
||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
@@ -131,8 +163,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
|
||||
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
||||
if training_args.do_train and model_args.quantization_device_map == "auto":
|
||||
raise ValueError("Cannot use device map for quantized models in training.")
|
||||
|
||||
if finetuning_args.use_dora and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
@@ -151,21 +183,33 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if finetuning_args.use_galore and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore is incompatible with DeepSpeed.")
|
||||
if (
|
||||
finetuning_args.use_badam
|
||||
and finetuning_args.badam_mode == "layer"
|
||||
and training_args.parallel_mode.value == "distributed"
|
||||
):
|
||||
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
||||
|
||||
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
if model_args.visual_inputs and data_args.packing:
|
||||
raise ValueError("Cannot use packing in MLLM fine-tuning.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
and model_args.quantization_bit is None
|
||||
and model_args.resize_vocab
|
||||
and finetuning_args.additional_target is None
|
||||
):
|
||||
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
|
||||
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
@@ -235,6 +279,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
elif training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
|
||||
model_args.device_map = {"": get_current_device()}
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
|
||||
|
||||
@@ -266,18 +311,25 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
if finetuning_args.stage != "sft":
|
||||
raise ValueError("vLLM engine only supports auto-regressive models.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("vLLM engine does not support quantization.")
|
||||
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
|
||||
|
||||
if model_args.rope_scaling is not None:
|
||||
raise ValueError("vLLM engine does not support RoPE scaling.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
|
||||
|
||||
model_args.device_map = "auto"
|
||||
if finetuning_args.stage == "rm" and model_args.visual_inputs:
|
||||
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
if model_args.export_dir is not None:
|
||||
model_args.device_map = {"": torch.device(model_args.export_device)}
|
||||
else:
|
||||
model_args.device_map = "auto"
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
@@ -294,6 +346,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
model_args.device_map = "auto"
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from .loader import load_model, load_tokenizer
|
||||
from .utils import find_all_linear_modules, load_valuehead_params
|
||||
from .loader import load_config, load_model, load_tokenizer
|
||||
from .utils.misc import find_all_linear_modules
|
||||
from .utils.valuehead import load_valuehead_params
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_config",
|
||||
"load_model",
|
||||
"load_tokenizer",
|
||||
"load_valuehead_params",
|
||||
|
||||
@@ -5,11 +5,13 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
|
||||
from .utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .utils.quantization import QuantizationMethod
|
||||
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
@@ -18,7 +20,11 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_adapter(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
|
||||
config: "PretrainedConfig",
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
@@ -32,9 +38,12 @@ def init_adapter(
|
||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
||||
return model
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
||||
raise ValueError("You can only use lora for quantized models.")
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
if not finetuning_args.pure_bf16:
|
||||
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
@@ -66,6 +75,8 @@ def init_adapter(
|
||||
for name, _ in model.named_modules():
|
||||
if ".0." in name:
|
||||
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||
elif ".1." in name: # MoD starts from layer 1
|
||||
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
|
||||
|
||||
trainable_layers = []
|
||||
for module_name in finetuning_args.name_module_trainable:
|
||||
@@ -79,7 +90,7 @@ def init_adapter(
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
if not finetuning_args.pure_bf16:
|
||||
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
|
||||
param.data = param.data.to(torch.float32)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
@@ -100,6 +111,10 @@ def init_adapter(
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
||||
is_mergeable = False
|
||||
|
||||
if model_args.use_unsloth:
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
|
||||
is_mergeable = False
|
||||
|
||||
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||
@@ -116,9 +131,15 @@ def init_adapter(
|
||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(
|
||||
model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
|
||||
)
|
||||
if model_args.use_unsloth:
|
||||
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
|
||||
else:
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
adapter_to_resume,
|
||||
is_trainable=is_trainable,
|
||||
offload_folder=model_args.offload_folder,
|
||||
)
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
@@ -129,9 +150,23 @@ def init_adapter(
|
||||
if finetuning_args.use_llama_pro:
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||
|
||||
if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
|
||||
if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
|
||||
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
|
||||
if (
|
||||
finetuning_args.use_dora
|
||||
and getattr(model, "quantization_method", None) is not None
|
||||
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
|
||||
):
|
||||
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
|
||||
|
||||
if model_args.resize_vocab and finetuning_args.additional_target is None:
|
||||
input_embeddings = model.get_input_embeddings()
|
||||
output_embeddings = model.get_output_embeddings()
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if module in [input_embeddings, output_embeddings]:
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
finetuning_args.additional_target = module_names
|
||||
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
@@ -139,24 +174,21 @@ def init_adapter(
|
||||
"lora_alpha": finetuning_args.lora_alpha,
|
||||
"lora_dropout": finetuning_args.lora_dropout,
|
||||
"use_rslora": finetuning_args.use_rslora,
|
||||
"modules_to_save": finetuning_args.additional_target,
|
||||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
|
||||
else:
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
modules_to_save=finetuning_args.additional_target,
|
||||
use_dora=finetuning_args.use_dora,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if not finetuning_args.pure_bf16:
|
||||
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
||||
from ..extras.misc import count_parameters, try_download_model_from_ms
|
||||
from .adapter import init_adapter
|
||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||
from .utils import load_valuehead_params, register_autoclass
|
||||
from .utils.misc import register_autoclass
|
||||
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .utils.unsloth import load_unsloth_pretrained_model
|
||||
from .utils.valuehead import load_valuehead_params
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
@@ -19,7 +22,17 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenizerModule(TypedDict):
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
processor: Optional["ProcessorMixin"]
|
||||
|
||||
|
||||
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
r"""
|
||||
Gets arguments to load config/tokenizer/model.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
model_args.model_name_or_path = try_download_model_from_ms(model_args)
|
||||
return {
|
||||
"trust_remote_code": True,
|
||||
@@ -29,22 +42,56 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
r"""
|
||||
Loads pretrained tokenizer. Must before load_model.
|
||||
Loads pretrained tokenizer.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
except ValueError: # try the fast one
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=True,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
if model_args.new_special_tokens is not None:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=model_args.new_special_tokens),
|
||||
replace_additional_special_tokens=False,
|
||||
)
|
||||
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
|
||||
if num_added_tokens > 0 and not model_args.resize_vocab:
|
||||
model_args.resize_vocab = True
|
||||
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
patch_tokenizer(tokenizer)
|
||||
return tokenizer
|
||||
|
||||
if model_args.visual_inputs:
|
||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
else:
|
||||
processor = None
|
||||
|
||||
return {"tokenizer": tokenizer, "processor": processor}
|
||||
|
||||
|
||||
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
||||
r"""
|
||||
Loads model config.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
|
||||
|
||||
def load_model(
|
||||
@@ -55,45 +102,42 @@ def load_model(
|
||||
add_valuehead: bool = False,
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Loads pretrained model. Must after load_tokenizer.
|
||||
Loads pretrained model.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
config = load_config(model_args)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
|
||||
|
||||
model = None
|
||||
if is_trainable and model_args.use_unsloth:
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
lazy_load = False
|
||||
if model_args.use_unsloth:
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
lazy_load = True
|
||||
elif is_trainable:
|
||||
model = load_unsloth_pretrained_model(config, model_args)
|
||||
|
||||
unsloth_kwargs = {
|
||||
"model_name": model_args.model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length,
|
||||
"dtype": model_args.compute_dtype,
|
||||
"load_in_4bit": model_args.quantization_bit == 4,
|
||||
"token": model_args.hf_hub_token,
|
||||
"device_map": {"": get_current_device()},
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
try:
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model_args.use_unsloth = False
|
||||
if model is None and not lazy_load:
|
||||
init_kwargs["config"] = config
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
|
||||
if model_args.adapter_name_or_path:
|
||||
model_args.adapter_name_or_path = None
|
||||
logger.warning("Unsloth does not support loading adapters.")
|
||||
if model_args.mixture_of_depths == "load":
|
||||
model = load_mod_pretrained_model(**init_kwargs)
|
||||
elif model_args.visual_inputs:
|
||||
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
|
||||
|
||||
if model is None:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
||||
if model_args.mixture_of_depths == "convert":
|
||||
model = convert_pretrained_model_to_mod(model, config, model_args)
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
if not lazy_load:
|
||||
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
if add_valuehead:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
patch_valuehead_model(model)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import PeftModel
|
||||
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_current_device, infer_optim_dtype
|
||||
from ..extras.packages import is_flash_attn2_available
|
||||
from ..extras.patches.llama_patch import apply_llama_patch
|
||||
from .utils import QuantizationMethod, add_z3_leaf_module
|
||||
from ..extras.misc import infer_optim_dtype
|
||||
from .utils.attention import configure_attn_implementation, print_attn_implementation
|
||||
from .utils.checkpointing import prepare_model_for_training
|
||||
from .utils.embedding import resize_embedding_layer
|
||||
from .utils.longlora import configure_longlora
|
||||
from .utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .utils.quantization import configure_quantization
|
||||
from .utils.rope import configure_rope
|
||||
from .utils.valuehead import configure_valuehead, prepare_valuehead_model
|
||||
from .utils.visual import autocast_projector_dtype
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -28,254 +27,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
|
||||
|
||||
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||
r"""
|
||||
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
|
||||
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||
"""
|
||||
if os.path.isfile(model_args.export_quantization_dataset):
|
||||
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
|
||||
data_files = model_args.export_quantization_dataset
|
||||
else:
|
||||
data_path = model_args.export_quantization_dataset
|
||||
data_files = None
|
||||
|
||||
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
|
||||
maxlen = model_args.export_quantization_maxlen
|
||||
|
||||
samples = []
|
||||
for _ in range(model_args.export_quantization_nsamples):
|
||||
while True:
|
||||
sample_idx = random.randint(0, len(dataset) - 1)
|
||||
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
if sample["input_ids"].size(1) >= maxlen:
|
||||
break # TODO: fix large maxlen
|
||||
|
||||
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
||||
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def _configure_attn_implementation(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
if model_args.flash_attn:
|
||||
if not is_flash_attn2_available():
|
||||
logger.warning("FlashAttention2 is not installed.")
|
||||
return
|
||||
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
|
||||
setattr(config, "attn_implementation", "flash_attention_2")
|
||||
else:
|
||||
init_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
else:
|
||||
init_kwargs["attn_implementation"] = "eager"
|
||||
|
||||
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
return
|
||||
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK scaling may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info(
|
||||
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
|
||||
)
|
||||
|
||||
|
||||
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.shift_attn:
|
||||
return
|
||||
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
apply_llama_patch()
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
|
||||
def _configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
init_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # ptq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
|
||||
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if quant_method == QuantizationMethod.GPTQ:
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
quantization_config["use_exllama"] = False # disable exllama
|
||||
|
||||
if quant_method == QuantizationMethod.AWQ:
|
||||
require_version("autoawq", "To fix: pip install autoawq")
|
||||
|
||||
if quant_method == QuantizationMethod.AQLM:
|
||||
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
|
||||
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
quant_bits = quantization_config.get("bits", "?")
|
||||
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
|
||||
init_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=model_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=_get_quantization_dataset(tokenizer, model_args),
|
||||
)
|
||||
init_kwargs["device_map"] = "auto"
|
||||
init_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type,
|
||||
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
|
||||
)
|
||||
|
||||
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
|
||||
if model_args.quantization_bit != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use auto device map.")
|
||||
|
||||
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
|
||||
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
|
||||
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
|
||||
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Resize token embeddings.
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
|
||||
params = [model.get_input_embeddings().weight]
|
||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||
params.append(model.get_output_embeddings().weight)
|
||||
|
||||
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||
else:
|
||||
context_maybe_zero3 = nullcontext()
|
||||
|
||||
with context_maybe_zero3:
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
logger.warning("Current model does not support resizing token embeddings.")
|
||||
return
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
with context_maybe_zero3:
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
|
||||
|
||||
def _fp32_forward_post_hook(
|
||||
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
) -> "torch.Tensor":
|
||||
return output.to(torch.float32)
|
||||
|
||||
|
||||
def _prepare_model_for_training(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
|
||||
) -> None:
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) add the upcasting of the lm_head in fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if not model_args.disable_gradient_checkpointing:
|
||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||
logger.warning("Current model does not support gradient checkpointing.")
|
||||
else:
|
||||
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
model.enable_input_require_grads()
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
||||
@@ -289,25 +40,24 @@ def patch_config(
|
||||
model_args: "ModelArguments",
|
||||
init_kwargs: Dict[str, Any],
|
||||
is_trainable: bool,
|
||||
add_valuehead: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
_configure_attn_implementation(config, model_args, init_kwargs)
|
||||
_configure_rope(config, model_args, is_trainable)
|
||||
_configure_longlora(config, model_args, is_trainable)
|
||||
_configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_attn_implementation(config, model_args)
|
||||
configure_rope(config, model_args, is_trainable)
|
||||
configure_longlora(config, model_args, is_trainable)
|
||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
|
||||
if add_valuehead:
|
||||
configure_valuehead(config)
|
||||
|
||||
if model_args.use_cache and not is_trainable:
|
||||
setattr(config, "use_cache", True)
|
||||
logger.info("Using KV cache for faster generation.")
|
||||
|
||||
if model_args.moe_aux_loss_coef is not None:
|
||||
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
|
||||
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
|
||||
elif getattr(config, "model_type", None) == "deepseek":
|
||||
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
setattr(config, "use_flash_attn", model_args.flash_attn)
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
@@ -316,22 +66,23 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
|
||||
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
|
||||
setattr(config, "output_router_logits", True)
|
||||
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
|
||||
if init_kwargs["low_cpu_mem_usage"]:
|
||||
if "device_map" not in init_kwargs:
|
||||
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
|
||||
if "device_map" not in init_kwargs and model_args.device_map:
|
||||
init_kwargs["device_map"] = model_args.device_map
|
||||
|
||||
if init_kwargs["device_map"] == "auto":
|
||||
init_kwargs["offload_folder"] = model_args.offload_folder
|
||||
|
||||
|
||||
def patch_model(
|
||||
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
is_trainable: bool,
|
||||
add_valuehead: bool,
|
||||
) -> None:
|
||||
gen_config = model.generation_config # check and fix generation config
|
||||
if not gen_config.do_sample and (
|
||||
@@ -344,25 +95,21 @@ def patch_model(
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
if add_valuehead:
|
||||
prepare_valuehead_model(model)
|
||||
|
||||
if model_args.resize_vocab:
|
||||
_resize_embedding_layer(model, tokenizer)
|
||||
resize_embedding_layer(model, tokenizer)
|
||||
|
||||
if model_args.visual_inputs:
|
||||
autocast_projector_dtype(model, model_args)
|
||||
|
||||
if is_trainable:
|
||||
_prepare_model_for_training(model, model_args)
|
||||
prepare_model_for_training(model, model_args)
|
||||
add_z3_leaf_module(model)
|
||||
|
||||
if getattr(model.config, "model_type", None) == "mixtral":
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
add_z3_leaf_module(model, MixtralSparseMoeBlock)
|
||||
|
||||
if getattr(model.config, "model_type", None) == "qwen2moe":
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
|
||||
if not model_args.use_unsloth:
|
||||
print_attn_implementation(model.config)
|
||||
|
||||
try:
|
||||
model.add_model_tags(["llama-factory"])
|
||||
|
||||
55
src/llmtuner/model/utils/attention.py
Normal file
55
src/llmtuner/model/utils/attention.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
|
||||
if model_args.flash_attn == "auto":
|
||||
return
|
||||
|
||||
elif model_args.flash_attn == "off":
|
||||
requested_attn_implementation = "eager"
|
||||
|
||||
elif model_args.flash_attn == "sdpa":
|
||||
if not is_sdpa_available():
|
||||
logger.warning("Torch>=2.1.1 is required for SDPA attention.")
|
||||
return
|
||||
|
||||
requested_attn_implementation = "sdpa"
|
||||
elif model_args.flash_attn == "fa2":
|
||||
if not is_flash_attn2_available():
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
return
|
||||
|
||||
requested_attn_implementation = "flash_attention_2"
|
||||
else:
|
||||
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
|
||||
|
||||
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
|
||||
setattr(config, "attn_implementation", requested_attn_implementation)
|
||||
else:
|
||||
setattr(config, "_attn_implementation", requested_attn_implementation)
|
||||
|
||||
|
||||
def print_attn_implementation(config: "PretrainedConfig") -> None:
|
||||
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
|
||||
attn_implementation = getattr(config, "attn_implementation", None)
|
||||
else:
|
||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
|
||||
if attn_implementation == "flash_attention_2":
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
elif attn_implementation == "sdpa":
|
||||
logger.info("Using torch SDPA for faster training and inference.")
|
||||
else:
|
||||
logger.info("Using vanilla Attention implementation.")
|
||||
94
src/llmtuner/model/utils/checkpointing.py
Normal file
94
src/llmtuner/model/utils/checkpointing.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import inspect
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras.constants import LAYERNORM_NAMES
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _gradient_checkpointing_enable(
|
||||
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
r"""
|
||||
Activates gradient checkpointing for the current model.
|
||||
|
||||
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
|
||||
"""
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
if not self.supports_gradient_checkpointing:
|
||||
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
|
||||
|
||||
if gradient_checkpointing_kwargs is None:
|
||||
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||
|
||||
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||
|
||||
def custom_gradient_checkpointing_func(func, *args, **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg) and torch.is_floating_point(arg):
|
||||
arg.requires_grad_(True)
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
self.enable_input_require_grads()
|
||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
else: # have already enabled input require gradients
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||
|
||||
|
||||
def _fp32_forward_post_hook(
|
||||
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
) -> "torch.Tensor":
|
||||
return output.to(torch.float32)
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
|
||||
) -> None:
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) add the upcasting of the lm_head in fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if not model_args.disable_gradient_checkpointing:
|
||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||
logger.warning("Current model does not support gradient checkpointing.")
|
||||
else:
|
||||
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
||||
58
src/llmtuner/model/utils/embedding.py
Normal file
58
src/llmtuner/model/utils/embedding.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Resize token embeddings.
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
|
||||
params = [model.get_input_embeddings().weight]
|
||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||
params.append(model.get_output_embeddings().weight)
|
||||
|
||||
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||
else:
|
||||
context_maybe_zero3 = nullcontext()
|
||||
|
||||
with context_maybe_zero3:
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if getattr(model, "quantization_method", None):
|
||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
raise ValueError("Current model does not support resizing embedding layers.")
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
with context_maybe_zero3:
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import (
|
||||
Cache,
|
||||
LlamaAttention,
|
||||
LlamaFlashAttention2,
|
||||
LlamaSdpaAttention,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from:
|
||||
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_torch_attn_forward(
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_attention_forward(
|
||||
self: "LlamaAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
@@ -39,10 +48,11 @@ def llama_torch_attn_forward(
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@@ -69,8 +79,9 @@ def llama_torch_attn_forward(
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
@@ -97,8 +108,8 @@ def llama_torch_attn_forward(
|
||||
|
||||
|
||||
# Modified from:
|
||||
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_flash_attn_forward(
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_flash_attention_2_forward(
|
||||
self: "LlamaFlashAttention2",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
@@ -117,7 +128,6 @@ def llama_flash_attn_forward(
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -134,9 +144,10 @@ def llama_flash_attn_forward(
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
@@ -192,7 +203,115 @@ def llama_flash_attn_forward(
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def apply_llama_patch() -> None:
|
||||
require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3")
|
||||
LlamaAttention.forward = llama_torch_attn_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attn_forward
|
||||
# Modified from:
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_sdpa_attention_forward(
|
||||
self: "LlamaSdpaAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional["Cache"] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
|
||||
return llama_attention_forward(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat(
|
||||
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||
dim=2,
|
||||
)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, :groupsz]
|
||||
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat(
|
||||
(
|
||||
attn_output[:, :, : self.num_heads // 2],
|
||||
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||
)
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||
|
||||
|
||||
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.shift_attn:
|
||||
return
|
||||
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
_apply_llama_patch()
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
@@ -1,49 +1,18 @@
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import cached_file
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.logging import get_logger
|
||||
from ...extras.logging import get_logger
|
||||
from .quantization import QuantizationMethod
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
class QuantizationMethod(str, Enum):
|
||||
r"""
|
||||
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
|
||||
"""
|
||||
|
||||
BITS_AND_BYTES = "bitsandbytes"
|
||||
GPTQ = "gptq"
|
||||
AWQ = "awq"
|
||||
AQLM = "aqlm"
|
||||
QUANTO = "quanto"
|
||||
|
||||
|
||||
def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None:
|
||||
r"""
|
||||
Sets module as a leaf module to skip partitioning in deepspeed zero3.
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
|
||||
set_z3_leaf_modules(model, [module])
|
||||
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora or galore.
|
||||
@@ -100,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
|
||||
return module_names
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||
return {key: f.get_tensor(key) for key in f.keys()}
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
||||
|
||||
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
||||
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
||||
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
|
||||
if "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
28
src/llmtuner/model/utils/mod.py
Normal file
28
src/llmtuner/model/utils/mod.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.constants import MOD_SUPPORTED_MODELS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
|
||||
from MoD import AutoMoDModelForCausalLM
|
||||
|
||||
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
|
||||
|
||||
|
||||
def convert_pretrained_model_to_mod(
|
||||
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
|
||||
) -> "PreTrainedModel":
|
||||
from MoD import apply_mod_to_hf
|
||||
|
||||
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
|
||||
raise ValueError("Current model is not supported by mixture-of-depth.")
|
||||
|
||||
model = apply_mod_to_hf(model)
|
||||
model = model.to(model_args.compute_dtype)
|
||||
return model
|
||||
53
src/llmtuner/model/utils/moe.py
Normal file
53
src/llmtuner/model/utils/moe.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
r"""
|
||||
Sets module as a leaf module to skip partitioning in deepspeed zero3.
|
||||
"""
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
return
|
||||
|
||||
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
|
||||
if getattr(model.config, "model_type", None) == "mixtral":
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "qwen2moe":
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "jamba":
|
||||
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "dbrx":
|
||||
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
|
||||
|
||||
set_z3_leaf_modules(model, [DbrxFFN])
|
||||
|
||||
|
||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.moe_aux_loss_coef is not None:
|
||||
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
|
||||
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
|
||||
|
||||
elif getattr(config, "model_type", None) == "deepseek":
|
||||
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
|
||||
|
||||
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
|
||||
setattr(config, "output_router_logits", is_trainable)
|
||||
146
src/llmtuner/model/utils/quantization.py
Normal file
146
src/llmtuner/model/utils/quantization.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import os
|
||||
import random
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import BitsAndBytesConfig, GPTQConfig
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.constants import FILEEXT2TYPE
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import get_current_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
class QuantizationMethod(str, Enum):
|
||||
r"""
|
||||
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
|
||||
"""
|
||||
|
||||
BITS_AND_BYTES = "bitsandbytes"
|
||||
GPTQ = "gptq"
|
||||
AWQ = "awq"
|
||||
AQLM = "aqlm"
|
||||
QUANTO = "quanto"
|
||||
|
||||
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||
r"""
|
||||
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
|
||||
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||
"""
|
||||
if os.path.isfile(model_args.export_quantization_dataset):
|
||||
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
|
||||
data_files = model_args.export_quantization_dataset
|
||||
else:
|
||||
data_path = model_args.export_quantization_dataset
|
||||
data_files = None
|
||||
|
||||
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
|
||||
maxlen = model_args.export_quantization_maxlen
|
||||
|
||||
samples = []
|
||||
for _ in range(model_args.export_quantization_nsamples):
|
||||
while True:
|
||||
sample_idx = random.randint(0, len(dataset) - 1)
|
||||
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
if sample["input_ids"].size(1) >= maxlen:
|
||||
break # TODO: fix large maxlen
|
||||
|
||||
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
||||
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
init_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # ptq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
|
||||
|
||||
if model_args.quantization_device_map != "auto":
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if quant_method == QuantizationMethod.GPTQ:
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
quantization_config.pop("disable_exllama", None) # remove deprecated args
|
||||
quantization_config["use_exllama"] = False # disable exllama
|
||||
|
||||
if quant_method == QuantizationMethod.AWQ:
|
||||
require_version("autoawq", "To fix: pip install autoawq")
|
||||
|
||||
if quant_method == QuantizationMethod.AQLM:
|
||||
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
|
||||
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
quant_bits = quantization_config.get("bits", "?")
|
||||
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
|
||||
init_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=model_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=_get_quantization_dataset(tokenizer, model_args),
|
||||
)
|
||||
init_kwargs["device_map"] = "auto"
|
||||
init_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type,
|
||||
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
|
||||
)
|
||||
|
||||
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
|
||||
if model_args.quantization_bit != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use auto device map.")
|
||||
|
||||
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
|
||||
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
|
||||
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
47
src/llmtuner/model/utils/rope.py
Normal file
47
src/llmtuner/model/utils/rope.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
return
|
||||
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK scaling may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
logger.info(
|
||||
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
|
||||
)
|
||||
setattr(config, "max_position_embeddings", model_args.model_max_length)
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info(
|
||||
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
|
||||
)
|
||||
88
src/llmtuner/model/utils/unsloth.py
Normal file
88
src/llmtuner/model/utils/unsloth.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import get_current_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_unsloth_kwargs(
|
||||
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"model_name": model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length or 4096,
|
||||
"dtype": model_args.compute_dtype,
|
||||
"load_in_4bit": model_args.quantization_bit == 4,
|
||||
"token": model_args.hf_hub_token,
|
||||
"device_map": {"": get_current_device()},
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
"fix_tokenizer": False,
|
||||
"trust_remote_code": True,
|
||||
"use_gradient_checkpointing": "unsloth",
|
||||
}
|
||||
|
||||
|
||||
def load_unsloth_pretrained_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments"
|
||||
) -> Optional["PreTrainedModel"]:
|
||||
r"""
|
||||
Optionally loads pretrained model with unsloth. Used in training.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
|
||||
try:
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model = None
|
||||
model_args.use_unsloth = False
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_unsloth_peft_model(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Gets the peft model for the pretrained model with unsloth. Used in training.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
unsloth_peft_kwargs = {
|
||||
"model": model,
|
||||
"max_seq_length": model_args.model_max_length,
|
||||
"use_gradient_checkpointing": "unsloth",
|
||||
}
|
||||
return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
|
||||
|
||||
def load_unsloth_peft_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Loads peft model with unsloth. Used in both training and inference.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
|
||||
try:
|
||||
if not is_trainable:
|
||||
unsloth_kwargs["use_gradient_checkpointing"] = False
|
||||
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
|
||||
if not is_trainable:
|
||||
FastLanguageModel.for_inference(model)
|
||||
|
||||
return model
|
||||
59
src/llmtuner/model/utils/valuehead.py
Normal file
59
src/llmtuner/model/utils/valuehead.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import torch
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_valuehead(config: "PretrainedConfig") -> None:
|
||||
if getattr(config, "model_type", None) == "llava":
|
||||
setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||
return {key: f.get_tensor(key) for key in f.keys()}
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
||||
|
||||
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
||||
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
||||
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
|
||||
if getattr(model.config, "model_type", None) == "llava":
|
||||
setattr(model, "lm_head", model.language_model.get_output_embeddings())
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
28
src/llmtuner/model/utils/visual.py
Normal file
28
src/llmtuner/model/utils/visual.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def autocast_projector_dtype(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
|
||||
) -> None:
|
||||
def _mm_projector_forward_post_hook(
|
||||
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
) -> "torch.Tensor":
|
||||
return output.to(model_args.compute_dtype)
|
||||
|
||||
if hasattr(model, mm_projector_name):
|
||||
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
||||
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -63,6 +64,11 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
|
||||
|
||||
@@ -24,8 +24,9 @@ def run_dpo(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -44,6 +45,10 @@ class CustomORPOTrainer(DPOTrainer):
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
|
||||
@@ -24,8 +24,9 @@ def run_orpo(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -124,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
else:
|
||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
|
||||
@@ -27,8 +27,9 @@ def run_ppo(
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from transformers import Trainer
|
||||
@@ -23,6 +24,10 @@ class CustomTrainer(Trainer):
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
|
||||
@@ -25,8 +25,9 @@ def run_pt(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -28,6 +29,10 @@ class PairwiseTrainer(Trainer):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
|
||||
@@ -25,8 +25,9 @@ def run_rm(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
||||
@@ -33,10 +32,6 @@ class ComputeMetrics:
|
||||
r"""
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
require_version("jieba", "To fix: pip install jieba")
|
||||
require_version("nltk", "To fix: pip install nltk")
|
||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -28,6 +29,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
|
||||
@@ -28,8 +28,9 @@ def run_sft(
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
@@ -47,6 +48,7 @@ def run_sft(
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
|
||||
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
|
||||
@@ -52,7 +52,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
|
||||
raise ValueError("Please merge adapters before quantizing the model.")
|
||||
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
tokenizer = load_tokenizer(model_args)["tokenizer"]
|
||||
get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
|
||||
|
||||
@@ -65,8 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
|
||||
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
|
||||
setattr(model.config, "torch_dtype", output_dtype)
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(output_dtype)
|
||||
model = model.to(output_dtype)
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
|
||||
@@ -5,7 +5,6 @@ from transformers import Trainer
|
||||
from transformers.optimization import get_scheduler
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_galore_available
|
||||
@@ -57,9 +56,14 @@ def create_modelcard_and_push(
|
||||
kwargs = {
|
||||
"tasks": "text-generation",
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||
"tags": ["llama-factory", finetuning_args.finetuning_type],
|
||||
}
|
||||
if data_args.dataset is not None:
|
||||
kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]
|
||||
|
||||
if model_args.use_unsloth:
|
||||
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
|
||||
|
||||
if not training_args.do_train:
|
||||
pass
|
||||
elif training_args.push_to_hub:
|
||||
@@ -87,7 +91,7 @@ def create_ref_model(
|
||||
)
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
tokenizer = load_tokenizer(ref_model_args)
|
||||
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
|
||||
ref_model = load_model(
|
||||
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
@@ -96,7 +100,7 @@ def create_ref_model(
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
else:
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
tokenizer = load_tokenizer(model_args)["tokenizer"]
|
||||
ref_model = load_model(
|
||||
tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
@@ -143,7 +147,7 @@ def create_reward_model(
|
||||
)
|
||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
tokenizer = load_tokenizer(reward_model_args)
|
||||
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
|
||||
reward_model = load_model(
|
||||
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
@@ -166,8 +170,6 @@ def _create_galore_optimizer(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
|
||||
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
||||
galore_targets = find_all_linear_modules(model)
|
||||
else:
|
||||
@@ -217,7 +219,7 @@ def _create_galore_optimizer(
|
||||
|
||||
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
|
||||
for param in nodecay_params:
|
||||
param_groups = [dict(params=[param])]
|
||||
param_groups = [dict(params=[param], weight_decay=0.0)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
for param in decay_params:
|
||||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
|
||||
@@ -237,7 +239,7 @@ def _create_galore_optimizer(
|
||||
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
|
||||
else:
|
||||
param_groups = [
|
||||
dict(params=nodecay_params),
|
||||
dict(params=nodecay_params, weight_decay=0.0),
|
||||
dict(params=decay_params, weight_decay=training_args.weight_decay),
|
||||
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
|
||||
]
|
||||
@@ -252,11 +254,9 @@ def _create_loraplus_optimizer(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("You should use LoRA tuning to activate LoRA+.")
|
||||
|
||||
default_lr = training_args.learning_rate
|
||||
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
|
||||
decay_args = {"weight_decay": training_args.weight_decay}
|
||||
embedding_lr = finetuning_args.loraplus_lr_embedding
|
||||
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
|
||||
@@ -279,16 +279,76 @@ def _create_loraplus_optimizer(
|
||||
|
||||
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||
param_groups = [
|
||||
dict(params=param_dict["lora_a"], **decay_args),
|
||||
dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
|
||||
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
|
||||
dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
|
||||
dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
|
||||
dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
|
||||
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
|
||||
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
|
||||
]
|
||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
|
||||
return optimizer
|
||||
|
||||
|
||||
def _create_badam_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
decay_params, nodecay_params = [], []
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if name in decay_param_names:
|
||||
decay_params.append(param)
|
||||
else:
|
||||
nodecay_params.append(param)
|
||||
|
||||
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||
param_groups = [
|
||||
dict(params=nodecay_params, weight_decay=0.0),
|
||||
dict(params=decay_params, weight_decay=training_args.weight_decay),
|
||||
]
|
||||
|
||||
if finetuning_args.badam_mode == "layer":
|
||||
from badam import BlockOptimizer
|
||||
|
||||
base_optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
optimizer = BlockOptimizer(
|
||||
base_optimizer=base_optimizer,
|
||||
named_parameters_list=list(model.named_parameters()),
|
||||
block_prefix_list=None,
|
||||
switch_block_every=finetuning_args.badam_switch_block_every,
|
||||
start_block=finetuning_args.badam_start_block,
|
||||
switch_mode=finetuning_args.badam_switch_mode,
|
||||
verbose=finetuning_args.badam_verbose,
|
||||
)
|
||||
logger.info(
|
||||
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
|
||||
f"switch block every {finetuning_args.badam_switch_block_every} steps, "
|
||||
f"default start block is {finetuning_args.badam_start_block}"
|
||||
)
|
||||
|
||||
elif finetuning_args.badam_mode == "ratio":
|
||||
from badam import BlockOptimizerRatio
|
||||
|
||||
assert finetuning_args.badam_update_ratio > 1e-6
|
||||
optimizer = BlockOptimizerRatio(
|
||||
param_groups=param_groups,
|
||||
named_parameters_list=list(model.named_parameters()),
|
||||
update_ratio=finetuning_args.badam_update_ratio,
|
||||
mask_mode=finetuning_args.badam_mask_mode,
|
||||
verbose=finetuning_args.badam_verbose,
|
||||
include_embedding=False,
|
||||
**optim_kwargs,
|
||||
)
|
||||
logger.info(
|
||||
f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
|
||||
f"mask mode is {finetuning_args.badam_mask_mode}"
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_custom_optimzer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
@@ -300,6 +360,9 @@ def create_custom_optimzer(
|
||||
if finetuning_args.loraplus_lr_ratio is not None:
|
||||
return _create_loraplus_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
return _create_badam_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
|
||||
def create_custom_scheduler(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
@@ -314,13 +377,12 @@ def create_custom_scheduler(
|
||||
scheduler_dict[param] = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer_dict[param],
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
|
||||
num_training_steps=num_training_steps * 2,
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
def scheduler_hook(param: "torch.nn.Parameter"):
|
||||
if param.grad is not None:
|
||||
scheduler_dict[param].step()
|
||||
scheduler_dict[param].step()
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
param.register_post_accumulate_grad_hook(scheduler_hook)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
@@ -17,6 +17,10 @@ if TYPE_CHECKING:
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
|
||||
self.manager = manager
|
||||
@@ -29,13 +33,16 @@ class WebChatModel(ChatModel):
|
||||
if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model
|
||||
model_name_or_path = os.environ.get("DEMO_MODEL")
|
||||
template = os.environ.get("DEMO_TEMPLATE")
|
||||
super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
|
||||
infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
|
||||
super().__init__(
|
||||
dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
|
||||
)
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self.engine is not None
|
||||
|
||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
def load_model(self, data) -> Generator[str, None, None]:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
lang = get("top.lang")
|
||||
error = ""
|
||||
@@ -70,8 +77,9 @@ class WebChatModel(ChatModel):
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
flash_attn=(get("top.booster") == "flash_attn"),
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
infer_backend=get("infer.infer_backend"),
|
||||
)
|
||||
@@ -79,7 +87,7 @@ class WebChatModel(ChatModel):
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
def unload_model(self, data) -> Generator[str, None, None]:
|
||||
lang = data[self.manager.get_elem_by_id("top.lang")]
|
||||
|
||||
if self.demo_mode:
|
||||
@@ -107,6 +115,7 @@ class WebChatModel(ChatModel):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
image: Optional[NDArray],
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
@@ -114,7 +123,7 @@ class WebChatModel(ChatModel):
|
||||
chatbot[-1][1] = ""
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
if tools:
|
||||
|
||||
@@ -3,13 +3,13 @@ import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import gradio as gr
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
|
||||
from ..extras.constants import (
|
||||
DATA_CONFIG,
|
||||
DEFAULT_MODULE,
|
||||
DEFAULT_TEMPLATE,
|
||||
MLLM_LIST,
|
||||
PEFT_METHODS,
|
||||
STAGES_USE_PAIR_DATA,
|
||||
SUPPORTED_MODELS,
|
||||
@@ -17,6 +17,11 @@ from ..extras.constants import (
|
||||
DownloadSource,
|
||||
)
|
||||
from ..extras.misc import use_modelscope
|
||||
from ..extras.packages import is_gradio_available
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
|
||||
@@ -101,6 +106,10 @@ def get_template(model_name: str) -> str:
|
||||
return "default"
|
||||
|
||||
|
||||
def get_visual(model_name: str) -> bool:
|
||||
return get_prefix(model_name) in MLLM_LIST
|
||||
|
||||
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type not in PEFT_METHODS:
|
||||
return gr.Dropdown(value=[], choices=[], interactive=False)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ...data import Role
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..utils import check_json_schema
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
@@ -14,15 +17,21 @@ if TYPE_CHECKING:
|
||||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
|
||||
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(show_copy_button=True)
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
|
||||
system = gr.Textbox(show_label=False)
|
||||
tools = gr.Textbox(show_label=False, lines=2)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
|
||||
system = gr.Textbox(show_label=False)
|
||||
tools = gr.Textbox(show_label=False, lines=3)
|
||||
|
||||
with gr.Column() as image_box:
|
||||
image = gr.Image(sources=["upload"], type="numpy")
|
||||
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
@@ -40,19 +49,21 @@ def create_chat_box(
|
||||
[chatbot, messages, query],
|
||||
).then(
|
||||
engine.chatter.stream,
|
||||
[chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
)
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
||||
|
||||
return (
|
||||
chat_box,
|
||||
chatbot,
|
||||
messages,
|
||||
dict(
|
||||
chat_box=chat_box,
|
||||
role=role,
|
||||
system=system,
|
||||
tools=tools,
|
||||
image_box=image_box,
|
||||
image=image,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
from ...extras.constants import DATA_CONFIG
|
||||
from ...extras.packages import is_gradio_available
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -29,28 +32,38 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
except Exception:
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
|
||||
if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
|
||||
return gr.Button(interactive=True)
|
||||
else:
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
|
||||
def _load_data_file(file_path: str) -> List[Any]:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
if file_path.endswith(".json"):
|
||||
return json.load(f)
|
||||
elif file_path.endswith(".jsonl"):
|
||||
return [json.loads(line) for line in f]
|
||||
else:
|
||||
return list(f)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_file: str = dataset_info[dataset[0]]["file_name"]
|
||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||
if data_file.endswith(".json"):
|
||||
data = json.load(f)
|
||||
elif data_file.endswith(".jsonl"):
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f] # noqa: C416
|
||||
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
|
||||
if os.path.isfile(data_path):
|
||||
data = _load_data_file(data_path)
|
||||
else:
|
||||
data = []
|
||||
for file_name in os.listdir(data_path):
|
||||
data.extend(_load_data_file(os.path.join(data_path, file_name)))
|
||||
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, list_dataset
|
||||
from .data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
@@ -18,7 +21,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train import export_model
|
||||
from ..common import get_save_dir
|
||||
from ..locales import ALERTS
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
@@ -23,9 +27,11 @@ def save_model(
|
||||
adapter_path: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
visual_inputs: bool,
|
||||
export_size: int,
|
||||
export_quantization_bit: int,
|
||||
export_quantization_dataset: str,
|
||||
export_device: str,
|
||||
export_legacy_format: bool,
|
||||
export_dir: str,
|
||||
export_hub_model_id: str,
|
||||
@@ -41,6 +47,8 @@ def save_model(
|
||||
error = ALERTS["err_no_dataset"][lang]
|
||||
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
|
||||
error = ALERTS["err_no_adapter"][lang]
|
||||
elif export_quantization_bit in GPTQ_BITS and adapter_path:
|
||||
error = ALERTS["err_gptq_lora"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
@@ -59,24 +67,28 @@ def save_model(
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
visual_inputs=visual_inputs,
|
||||
export_dir=export_dir,
|
||||
export_hub_model_id=export_hub_model_id or None,
|
||||
export_size=max_shard_size,
|
||||
export_size=export_size,
|
||||
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_device=export_device,
|
||||
export_legacy_format=export_legacy_format,
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args)
|
||||
torch_gc()
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
|
||||
export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
|
||||
export_legacy_format = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
@@ -95,9 +107,11 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.template"),
|
||||
max_shard_size,
|
||||
engine.manager.get_elem_by_id("top.visual_inputs"),
|
||||
export_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
export_device,
|
||||
export_legacy_format,
|
||||
export_dir,
|
||||
export_hub_model_id,
|
||||
@@ -106,9 +120,10 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
|
||||
return dict(
|
||||
max_shard_size=max_shard_size,
|
||||
export_size=export_size,
|
||||
export_quantization_bit=export_quantization_bit,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_device=export_device,
|
||||
export_legacy_format=export_legacy_format,
|
||||
export_dir=export_dir,
|
||||
export_hub_model_id=export_hub_model_id,
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from .chatbot import create_chat_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
@@ -25,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems.update({infer_backend})
|
||||
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||
|
||||
chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(chat_elems)
|
||||
|
||||
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
|
||||
)
|
||||
|
||||
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||
lambda: ([], []), outputs=[chatbot, messages]
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
|
||||
|
||||
engine.manager.get_elem_by_id("top.visual_inputs").change(
|
||||
lambda enabled: gr.Column(visible=enabled),
|
||||
[engine.manager.get_elem_by_id("top.visual_inputs")],
|
||||
[chat_elems["image_box"]],
|
||||
)
|
||||
|
||||
return elem_dict
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ...data import templates
|
||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from ..common import get_model_path, get_template, list_adapters, save_config
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
|
||||
from ..utils import can_quantize
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
@@ -27,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
|
||||
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
|
||||
).then(get_template, [model_name], [template], queue=False).then(
|
||||
get_visual, [model_name], [visual_inputs], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
@@ -56,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
|
||||
template=template,
|
||||
rope_scaling=rope_scaling,
|
||||
booster=booster,
|
||||
visual_inputs=visual_inputs,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from ...extras.constants import TRAINING_STAGES
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
|
||||
from ..components.data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
@@ -23,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
@@ -134,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from .chatter import WebChatModel
|
||||
from .common import get_model_path, list_dataset, load_config
|
||||
@@ -10,6 +8,10 @@ from .runner import Runner
|
||||
from .utils import get_time
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
class Engine:
|
||||
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
|
||||
self.demo_mode = demo_mode
|
||||
@@ -29,7 +31,7 @@ class Engine:
|
||||
|
||||
return output_dict
|
||||
|
||||
def resume(self) -> Generator[Dict[Component, Component], None, None]:
|
||||
def resume(self):
|
||||
user_config = load_config() if not self.demo_mode else {}
|
||||
lang = user_config.get("lang", None) or "en"
|
||||
|
||||
@@ -41,6 +43,7 @@ class Engine:
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||
init_dict["infer.image_box"] = {"visible": False}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
@@ -55,7 +58,7 @@ class Engine:
|
||||
else:
|
||||
yield self._update_component({"eval.resume_btn": {"value": True}})
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Component]:
|
||||
def change_lang(self, lang: str):
|
||||
return {
|
||||
elem: elem.__class__(**LOCALES[elem_name][lang])
|
||||
for elem_name, elem in self.manager.get_elem_iter()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import gradio as gr
|
||||
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import save_config
|
||||
from .components import (
|
||||
create_chat_box,
|
||||
@@ -13,6 +12,10 @@ from .css import CSS
|
||||
from .engine import Engine
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def create_ui(demo_mode: bool = False) -> gr.Blocks:
|
||||
engine = Engine(demo_mode=demo_mode, pure_chat=False)
|
||||
|
||||
@@ -55,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
engine.manager.add_elems("top", dict(lang=lang))
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
|
||||
_, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.add_elems("infer", chat_elems)
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||
|
||||
@@ -129,6 +129,17 @@ LOCALES = {
|
||||
"label": "加速方式",
|
||||
},
|
||||
},
|
||||
"visual_inputs": {
|
||||
"en": {
|
||||
"label": "Visual inputs",
|
||||
},
|
||||
"ru": {
|
||||
"label": "визуальные входы",
|
||||
},
|
||||
"zh": {
|
||||
"label": "图像输入",
|
||||
},
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
@@ -1073,6 +1084,17 @@ LOCALES = {
|
||||
"placeholder": "工具列表(非必填)",
|
||||
},
|
||||
},
|
||||
"image": {
|
||||
"en": {
|
||||
"label": "Image (optional)",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Изображение (по желанию)",
|
||||
},
|
||||
"zh": {
|
||||
"label": "图像(非必填)",
|
||||
},
|
||||
},
|
||||
"query": {
|
||||
"en": {
|
||||
"placeholder": "Input...",
|
||||
@@ -1150,7 +1172,7 @@ LOCALES = {
|
||||
"value": "清空历史",
|
||||
},
|
||||
},
|
||||
"max_shard_size": {
|
||||
"export_size": {
|
||||
"en": {
|
||||
"label": "Max shard size (GB)",
|
||||
"info": "The maximum size for a model file.",
|
||||
@@ -1192,6 +1214,20 @@ LOCALES = {
|
||||
"info": "量化过程中使用的校准数据集。",
|
||||
},
|
||||
},
|
||||
"export_device": {
|
||||
"en": {
|
||||
"label": "Export device",
|
||||
"info": "Which device should be used to export model.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Экспорт устройство",
|
||||
"info": "Какое устройство следует использовать для экспорта модели.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出设备",
|
||||
"info": "导出模型使用的设备类型。",
|
||||
},
|
||||
},
|
||||
"export_legacy_format": {
|
||||
"en": {
|
||||
"label": "Export legacy format",
|
||||
@@ -1287,7 +1323,12 @@ ALERTS = {
|
||||
"err_no_export_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"ru": "Пожалуйста, укажите каталог для экспорта.",
|
||||
"zh": "请填写导出目录",
|
||||
"zh": "请填写导出目录。",
|
||||
},
|
||||
"err_gptq_lora": {
|
||||
"en": "Please merge adapters before quantizing the model.",
|
||||
"ru": "Пожалуйста, объедините адаптеры перед квантованием модели.",
|
||||
"zh": "量化模型前请先合并适配器。",
|
||||
},
|
||||
"err_failed": {
|
||||
"en": "Failed.",
|
||||
|
||||
@@ -60,4 +60,5 @@ class Manager:
|
||||
self._id_to_elem["top.template"],
|
||||
self._id_to_elem["top.rope_scaling"],
|
||||
self._id_to_elem["top.booster"],
|
||||
self._id_to_elem["top.visual_inputs"],
|
||||
}
|
||||
|
||||
@@ -4,9 +4,7 @@ import time
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator
|
||||
|
||||
import gradio as gr
|
||||
import transformers
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
@@ -14,13 +12,20 @@ from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.logging import LoggerHandler
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from ..train import run_exp
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
@@ -62,7 +67,7 @@ class Runner:
|
||||
if not model_path:
|
||||
return ALERTS["err_no_path"][lang]
|
||||
|
||||
if len(dataset) == 0:
|
||||
if not dataset:
|
||||
return ALERTS["err_no_dataset"][lang]
|
||||
|
||||
if not from_preview and self.demo_mode:
|
||||
@@ -117,8 +122,9 @@ class Runner:
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn=(get("top.booster") == "flashattn"),
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
cutoff_len=get("train.cutoff_len"),
|
||||
@@ -217,8 +223,9 @@ class Runner:
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn=(get("top.booster") == "flashattn"),
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("eval.dataset_dir"),
|
||||
dataset=",".join(get("eval.dataset")),
|
||||
cutoff_len=get("eval.cutoff_len"),
|
||||
@@ -239,7 +246,7 @@ class Runner:
|
||||
|
||||
return args
|
||||
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||
error = self._initialize(data, do_train, from_preview=True)
|
||||
if error:
|
||||
@@ -249,7 +256,7 @@ class Runner:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield {output_box: gen_cmd(args)}
|
||||
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||
error = self._initialize(data, do_train, from_preview=False)
|
||||
if error:
|
||||
@@ -263,19 +270,19 @@ class Runner:
|
||||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
|
||||
def preview_train(self, data):
|
||||
yield from self._preview(data, do_train=True)
|
||||
|
||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
|
||||
def preview_eval(self, data):
|
||||
yield from self._preview(data, do_train=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
|
||||
def run_train(self, data):
|
||||
yield from self._launch(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
|
||||
def run_eval(self, data):
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
|
||||
def monitor(self):
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
@@ -332,7 +339,7 @@ class Runner:
|
||||
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
|
||||
def save_args(self, data):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
@@ -351,7 +358,7 @@ class Runner:
|
||||
save_path = save_args(config_path, config_dict)
|
||||
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||
|
||||
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
|
||||
def load_args(self, lang: str, config_path: str):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
config_dict = load_args(config_path)
|
||||
if config_dict is None:
|
||||
|
||||
@@ -3,21 +3,24 @@ import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ..extras.packages import is_matplotlib_available
|
||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import smooth
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..extras.callbacks import LogCallback
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..extras.callbacks import LogCallback
|
||||
|
||||
|
||||
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
|
||||
if not callback.max_steps:
|
||||
return gr.Slider(visible=False)
|
||||
|
||||
Reference in New Issue
Block a user