76 Commits

Author SHA1 Message Date
hiyouga
3cef844079 fix setup
Former-commit-id: 7d3e7db46a5f8672dd57fa5fcc03822e175047f9
2024-04-28 03:49:13 +08:00
hiyouga
4dcd47100d fix llava rlhf
Former-commit-id: f6863cbbcbf960d6481296c6cae3e40fd70e4e14
2024-04-28 03:01:49 +08:00
hiyouga
a412b4ed4a add models to 0.7.0
Former-commit-id: 436d3754452f839c617839ab3bbaacc4a8908e19
2024-04-28 01:50:30 +08:00
hiyouga
544a6259b6 update readme
Former-commit-id: c9190fe36f511c3a5149d45c85a10b02a57fa88a
2024-04-26 23:39:19 +08:00
hiyouga
c501f377dd release v0.7.0
Former-commit-id: 45bb89cb4d26a6b3fb5360bc90ab950738fe4920
2024-04-26 23:18:00 +08:00
hiyouga
cb8b8f40cd update readme
Former-commit-id: f3d4b46338d4d484b205d0651a1fa7b2e77a1654
2024-04-26 20:09:14 +08:00
hiyouga
70bed8ad8f support Qwen1.5 110B
Former-commit-id: d6e5ecaf4109127bab24e39a0696076bceb0b37c
2024-04-26 19:59:22 +08:00
hiyouga
51f776ae2a fix llava qlora
Former-commit-id: 01c5a669f6fe598aac1758a700a7607da37db1bc
2024-04-26 18:00:23 +08:00
hiyouga
697bc20941 add llava to llamaboard
Former-commit-id: deaaff0a9de0eef9691991c99cd797461b1165cc
2024-04-26 06:41:35 +08:00
hiyouga
1480e3a88f update readme
Former-commit-id: df1155245d3f71ba4f3361d43aa662ab3b024de8
2024-04-26 05:49:26 +08:00
hoshi-hiyouga
19029d5b0f Merge pull request #3454 from hiyouga/mllm
Support fine-tuning LLaVA-1.5 MLLM @BUAADreamer 

Former-commit-id: c4195d1e26349795f7aad5c10a8a9e2abb7b64a3
2024-04-26 05:46:29 +08:00
hiyouga
7773ac0ead update readme
Former-commit-id: 41728fd74de7bec0cc6135aef9dfa3ae9fe7af73
2024-04-26 05:44:30 +08:00
hiyouga
23b881bff1 support mllm hf inference
Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
2024-04-26 05:34:58 +08:00
hoshi-hiyouga
10a6c395bb Merge pull request #3450 from BUAADreamer/mllm
Add Multimodal LLM Finetuning

Former-commit-id: 7cacbcfdf7391080ef43eb2b2c79a5237e6120e8
2024-04-26 05:30:30 +08:00
hoshi-hiyouga
f9a7732a1f Update preprocess.py
Former-commit-id: 0e376eab23d38b8fca05f054f3cde308756ee3b1
2024-04-26 04:10:28 +08:00
hoshi-hiyouga
c37582af02 Update aligner.py
Former-commit-id: 855489074c469f47572153df0fa1e251b187b232
2024-04-26 03:48:34 +08:00
hoshi-hiyouga
ece67f8c7f Update parser.py
Former-commit-id: 4df75e8a9a391565cc3eec69bc0ebf5d5192de61
2024-04-26 03:35:39 +08:00
hoshi-hiyouga
e1838e76fe Update loader.py
Former-commit-id: 6a5f2e2ab7304113ff71cb77aafff6a1f74831f8
2024-04-26 03:33:07 +08:00
hoshi-hiyouga
2eede9ffd6 Update workflow.py
Former-commit-id: 5b8b5b975716d539ae2fae8536f79e106aa0b566
2024-04-26 03:29:12 +08:00
hoshi-hiyouga
a6f6b406b3 Update loader.py
Former-commit-id: 72d4817a15f6916706828ea2a61d808183c23773
2024-04-26 03:22:40 +08:00
hoshi-hiyouga
279439abbe update hparam name
Former-commit-id: 9941adfbf06db37f8ba32c4555f6e58e27188aaf
2024-04-26 02:49:39 +08:00
hoshi-hiyouga
13117b69d7 delete llava template (use vicuna)
Former-commit-id: 420e64970e5a0e45453041927e0366ee8beb73d5
2024-04-26 02:20:47 +08:00
BUAADreamer
5d03ac642d modify some bug
Former-commit-id: 593b7b004df74bd24361c9883401a656c08fb589
2024-04-25 22:59:46 +08:00
BUAADreamer
5062ee547e modify some style
Former-commit-id: 1291c7ee39361dd75247c67f04dcf20b472faf83
2024-04-25 22:40:53 +08:00
BUAADreamer
59817c27e3 modify some style
Former-commit-id: d578a90cefa7ec813355795bdd6ead5ee558ce26
2024-04-25 22:40:25 +08:00
BUAADreamer
759bee48d2 merge some func
Former-commit-id: 3085107c44715e4b2ca96d73b20d90c172b95219
2024-04-25 22:35:17 +08:00
BUAADreamer
514ffafc12 modify some style
Former-commit-id: 053062abc007014a7fde95c5ae9f4d859893d8ad
2024-04-25 22:04:09 +08:00
BUAADreamer
8b2a735c14 modify some style
Former-commit-id: b016e6a671a2f228f0bdd9b8d5995b4669609655
2024-04-25 21:58:18 +08:00
BUAADreamer
10d59e9e4a make dataset script
Former-commit-id: 25892f958da14976025a775febf628cd0e0a3d85
2024-04-25 21:32:01 +08:00
BUAADreamer
058ed5e607 modify style
Former-commit-id: c1f1df99e4dc3d0aadf1207b4e9a16218187fd5a
2024-04-25 21:29:50 +08:00
BUAADreamer
110c2ce2a5 modify style
Former-commit-id: 3bffc1e1b8bcc4582cebea06d35e5146163c7bec
2024-04-25 21:27:48 +08:00
BUAADreamer
c425436676 modify style
Former-commit-id: 54b713d0c4ffdfc6a7faeb14471b58bb1cd8acf5
2024-04-25 21:15:16 +08:00
BUAADreamer
266fe908e3 Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: c4bb5af69c5bbf0b1ea044cbb2b18acddc6733ac
2024-04-25 21:08:40 +08:00
BUAADreamer
dbd905438b add some
Former-commit-id: 8d035a849c4a441d457791aab073861adf69a09f
2024-04-25 21:08:32 +08:00
hoshi-hiyouga
d64c87f928 Merge pull request #3449 from hiyouga/mllm
add webui backend option

Former-commit-id: 372fcedef40b79fe8bd3932c06c720f2a03db6e6
2024-04-25 20:58:16 +08:00
hiyouga
29eebef696 add webui backend option
Former-commit-id: 3764586cb3ed64fe376d0ae420ff5690c28459e2
2024-04-25 20:49:23 +08:00
hiyouga
7bfbcb1fe3 vllm + lora support
Former-commit-id: 8cb86ba355195f5d6dcb95ee6b6b7203463a34db
2024-04-25 20:24:31 +08:00
BUAADreamer
9b210cf4b3 rm some
Former-commit-id: 2c85b4fabbebd8b51eee53f5d29184d4a6e97569
2024-04-25 20:09:43 +08:00
BUAADreamer
f74e640565 Merge branch 'hiyouga:main' into main
Former-commit-id: 131d0bcd554dedd794add7eb3d7b1201cac80e7c
2024-04-25 20:02:50 +08:00
BUAADreamer
d1d08d066a merge data part to the text stream
Former-commit-id: 80537d580119d9d5a06ab236a5284aaae2f83b5b
2024-04-25 19:58:47 +08:00
hiyouga
6be321b5da fix #3374
Former-commit-id: 0097d7968b3b570e1705caff26f42d9ed71ad974
2024-04-25 19:56:49 +08:00
BUAADreamer
3c792174db merge data part to the text stream
Former-commit-id: 7ee20286d9bcc2d5378bfd6bb02cd3648396d873
2024-04-25 19:19:59 +08:00
hiyouga
9aeb88c426 add export_device in webui #3333
Former-commit-id: 30ebd3652809d73941e0a5e4a8be11d989faf98d
2024-04-25 19:02:32 +08:00
BUAADreamer
00e2a272ef merge model part to the text stream
Former-commit-id: b6fcb832ddaed4647d6f2b926f3dfccd47f3ea84
2024-04-25 08:20:41 +08:00
BUAADreamer
5142349661 remove error
Former-commit-id: 2bcd1c7dc3595f17ae4e2c4475196cc2d03d0e75
2024-04-25 01:01:59 +08:00
BUAADreamer
0e3cc52327 remove conflicts
Former-commit-id: e5750ee202eb67cf5fc54f464548e2eb43d00900
2024-04-25 00:56:06 +08:00
BUAADreamer
6c1db2d012 remove conflicts
Former-commit-id: f8b637eb76cba7ec229e2978068805ad1cca8adb
2024-04-25 00:34:22 +08:00
BUAADreamer
12c51655ce add llava and instructblip
Former-commit-id: 142fb6f4541a1acfefe66ff2574dabde53b00c06
2024-04-25 00:22:43 +08:00
hiyouga
36be12a3b7 update tool template
Former-commit-id: c72a1981859818c257c5271d32e03c9d3c344206
2024-04-25 00:21:34 +08:00
hiyouga
21fac4c98c fix log level
Former-commit-id: 8d21302f6201b3f33c10f61f3559bd95be3363c2
2024-04-24 23:42:59 +08:00
hiyouga
83404c4fa9 support new special token #3420
Former-commit-id: f5c6a47f5193ab3a6c137580992bdcce0b31fdd5
2024-04-24 23:39:31 +08:00
hoshi-hiyouga
12f852b8d4 fix phi template
Former-commit-id: 14a1ff665eaebfc618229efbe96f09848d52faec
2024-04-24 13:55:14 +08:00
hoshi-hiyouga
a88873116a fix webchatmodel
Former-commit-id: dc6d8b5dc42c363dd180aaf90c9a2f2d0cce6725
2024-04-24 13:54:21 +08:00
hoshi-hiyouga
7cfcd69c64 fix inference in llamaboard
Former-commit-id: 5e631915157083b61e2d5a183e0c91f2d11f416e
2024-04-24 13:53:39 +08:00
hiyouga
a5eabbe933 add olmo 1.7
Former-commit-id: 86a3fb3a141d2702b15af08df36ffcf9b3d6de14
2024-04-24 05:50:50 +08:00
hiyouga
aa25716a5d add dbrx and jamba models
Former-commit-id: ce35c80b4b00152185285d6064939803d14487f0
2024-04-24 05:39:52 +08:00
hiyouga
94c8219575 fix bug
Former-commit-id: 38e164fe4aaea6f0baf121a720291ca42643ba8c
2024-04-24 05:21:18 +08:00
hiyouga
ad24a2a0c9 fix bug
Former-commit-id: 271c24d2c82d645fa9072e6de94ca38f20411537
2024-04-24 05:10:07 +08:00
hiyouga
c05027d14a remove redundant code
Former-commit-id: 4a7a7ad2bcdc493458084f5f3d384239228b7d5a
2024-04-24 05:02:18 +08:00
hiyouga
5420905a2e support unsloth generate
Former-commit-id: 0ef1ad9f505dba71db9342f524cc3a7565e5e09e
2024-04-24 04:46:53 +08:00
hiyouga
03f2e3284a refactor patcher
Former-commit-id: 263cfe1294f5c3188f5e8d65791f35ee0d87315a
2024-04-24 03:02:23 +08:00
hiyouga
d2bb1b3a6b reenable sdpa and fast tok by default
Former-commit-id: 9e00902dbedc71d55743d1bf237843506a557891
2024-04-24 02:18:44 +08:00
hiyouga
35c4a2c212 fix #3347 #3387
Former-commit-id: c253c18185a29b59190f3e0ed236c2bb4c788085
2024-04-24 01:30:16 +08:00
hiyouga
1e4010a1fb support phi-3
Former-commit-id: 7e8ffa9beee3893e051ceeade443bd56c4a07b1c
2024-04-24 00:28:53 +08:00
BUAADreamer
1451297c78 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: 67800c565b086f362b8cf131b0c9babaa7a7ebc7
2024-04-23 19:22:42 +08:00
BUAADreamer
0b99b13786 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: b78b5f290aa38a7454e101ee9703fb6fac5064ac
2024-04-23 18:47:03 +08:00
BUAADreamer
f5edbf2b49 Merge branch 'hiyouga:main' into main
Former-commit-id: 6287d1b789c631205c1033adf036e28deaef4167
2024-04-23 18:46:12 +08:00
BUAADreamer
ab6dc0ea30 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: a730f89a972f1a9d37c718c716f199cb8d4903b2
2024-04-23 18:45:43 +08:00
hiyouga
79d34ce0f3 update examples
Former-commit-id: 8bf55682cdfbbdca0f01073eac0084c20a6a09d1
2024-04-23 18:29:46 +08:00
hiyouga
1d2e372a8e update readme
Former-commit-id: d4eaee262a64e716ce475dc4eb18d8d9697d8dd8
2024-04-22 17:09:17 +08:00
hiyouga
f6a53d83c8 update readme
Former-commit-id: 3eab580703ee01a0d2d75e7f01df5165af551386
2024-04-22 00:51:35 +08:00
hiyouga
4ec56dd958 update readme
Former-commit-id: fdca136309709e43d75a831252b9375a5a99635a
2024-04-22 00:42:25 +08:00
hiyouga
ba06eb65ca update readme and examples
Former-commit-id: 27dd9bf201c24f7804811398bc2758966ec78432
2024-04-22 00:37:32 +08:00
hiyouga
be716972fe remove extras
Former-commit-id: d67e972f8c3d5273e589c8c85c0a1620f59785c5
2024-04-22 00:35:41 +08:00
hiyouga
719585a128 update readme
Former-commit-id: 3a8c17907c71f46b1b37501e2afdc99ad89fb4bc
2024-04-22 00:21:01 +08:00
hiyouga
348f29aa50 set dev version
Former-commit-id: b9557887d7506ff57b2b2bf490092aac4e4becf0
2024-04-21 23:14:30 +08:00
78 changed files with 1869 additions and 795 deletions

README.md
View File

@@ -43,8 +43,8 @@ Choose your path:
 ## Features
 
-- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
-- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
 - **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
@@ -68,9 +68,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
-[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
-
-[24/04/19] We supported **Meta Llama 3** model series.
+[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See `examples/lora_single_gpu/sft_mllm.sh` for usage.
+
+[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
+
+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
 
 [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
@@ -110,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
 
-[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
 
 [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
@@ -134,34 +136,38 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Supported Models
 
 | Model | Model size | Default module | Template |
 | ----- | ---------- | -------------- | -------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
 | [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
 | [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
 | [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
+| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
 | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
 > [!NOTE]
-> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
+> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
 >
-> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
+> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
+>
+> Remember to use the **SAME** template in training and inference.
 
 Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
@@ -232,6 +238,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
 - [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
+- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
 - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -251,7 +258,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
-- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
 
 </details>
@@ -286,15 +293,15 @@ huggingface-cli login
 \* *estimated*
 
-| Method | Bits | 7B | 13B | 30B | 70B | 8x7B | 8x22B |
-| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
-| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
-| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB | 1200GB |
-| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB | 400GB |
-| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 120GB | 320GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB | 160GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB | 96GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB | 48GB |
+| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
+| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
 
 ## Getting Started
@@ -315,7 +322,7 @@ cd LLaMA-Factory
 pip install -e .[metrics]
 ```
 
-Extra dependencies available: deepspeed, metrics, unsloth, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: deepspeed, metrics, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
 
 <details><summary>For Windows users</summary>
@@ -329,7 +336,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
 </details>
 
-### LLaMA Board GUI
+### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
 
 > [!IMPORTANT]
 > LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
@@ -342,6 +349,16 @@ export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
 python src/train_web.py # or python -m llmtuner.webui.interface
 ```
 
+<details><summary>For Alibaba Cloud users</summary>
+
+If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
+
+```bash
+export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
+```
+
+</details>
+
 #### Use Docker
 
 ```bash
@@ -371,7 +388,7 @@ docker compose -f ./docker-compose.yml up -d
 </details>
 
-### Command Line Interface
+### Train with Command Line Interface
 
 See [examples/README.md](examples/README.md) for usage.
@@ -381,13 +398,13 @@ Use `python src/train_bash.py -h` to display arguments description.
 ```bash
 CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
-    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
-    --template mistral \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
+    --template llama3 \
     --infer_backend vllm \
     --vllm_enforce_eager
 ```
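
Since `api_demo.py` exposes an OpenAI-style API, a standard chat-completions request works against it. A minimal sketch, assuming the server started above is reachable on port 8000 and serves the conventional `/v1/chat/completions` route:

```python
# Minimal sketch of querying the OpenAI-style endpoint started above.
# The route and response shape follow the standard OpenAI chat-completions schema,
# which is an assumption here; host/port come from API_PORT=8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```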
-### Use ModelScope Hub
+### Download from ModelScope Hub
 
 If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
@@ -395,7 +412,7 @@ If you have trouble with downloading models and datasets from Hugging Face, you
 export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
 ```
 
-Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `modelscope/Llama-2-7b-ms`.
+Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
 
 ## Projects using LLaMA Factory
@@ -444,7 +461,7 @@ If you have a project that should be incorporated, please contact via email or c
 This repository is licensed under the [Apache-2.0 License](LICENSE).
 
-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## Citation

View File

@@ -11,7 +11,7 @@
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
 [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
 
 👋 加入我们的[微信群](assets/wechat.jpg)。
@@ -23,7 +23,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 选择你的打开方式:
 
-- **Colab**:https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
 - **本地机器**:请见[如何使用](#如何使用)
 
 ## 目录
@@ -43,8 +43,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 项目特色
 
-- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
-- **集成方法**:增量预训练、指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
+- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
+- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
 - **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
 - **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
@@ -68,9 +68,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 更新日志
 
-[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`。
-
-[24/04/19] 我们支持了 **Meta Llama 3** 系列模型。
+[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 `examples/lora_single_gpu/sft_mllm.sh`。
+
+[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
+
+[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`。
 
 [24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`。
@@ -110,7 +112,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 [23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
 
-[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
+[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn fa2` 参数以启用 FlashAttention-2。
 
 [23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
@@ -134,34 +136,38 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 模型
 
 | 模型名 | 模型大小 | 默认模块 | Template |
 | ------ | -------- | -------- | -------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
 | [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
 | [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
 | [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
+| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
 | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
 > [!NOTE]
-> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
+> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以得到更好的效果。
 >
-> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
+> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
+>
+> 请务必在训练和推理时使用**完全一致**的模板。
 
 项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
@@ -232,6 +238,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
 - [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
+- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
 - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -251,7 +258,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
-- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
 
 </details>
@@ -286,15 +293,15 @@ huggingface-cli login
 \* *估算值*
 
-| 方法 | 精度 | 7B | 13B | 30B | 70B | 8x7B | 8x22B |
-| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
-| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
-| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB | 1200GB |
-| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB | 400GB |
-| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 120GB | 320GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB | 160GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB | 96GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB | 48GB |
+| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
+| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
 
 ## 如何使用
@@ -315,7 +322,7 @@ cd LLaMA-Factory
 pip install -e .[metrics]
 ```
 
-可选的额外依赖项:deepspeed、metrics、unsloth、galore、badam、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
+可选的额外依赖项:deepspeed、metrics、galore、badam、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
 
 <details><summary>Windows 用户指南</summary>
@@ -329,10 +336,10 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 </details>
 
-### LLaMA Board 可视化界面
+### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
 
 > [!IMPORTANT]
-> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行分布式训练。
+> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。
 
 #### 使用本地环境
@@ -342,6 +349,16 @@ export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
 python src/train_web.py # 或 python -m llmtuner.webui.interface
 ```
 
+<details><summary>阿里云用户指南</summary>
+
+如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
+
+```bash
+export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
+```
+
+</details>
+
 #### 使用 Docker
 
 ```bash
@@ -371,23 +388,23 @@ docker compose -f ./docker-compose.yml up -d
 </details>
 
-### 命令行接口
+### 利用命令行接口训练
 
 使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
 
-使用 `python src/train_bash.py -h` 查看参数文档。
+您可以执行 `python src/train_bash.py -h` 查看参数文档。
 
-### 使用 OpenAI 风格 API 和 vLLM 部署
+### 利用 vLLM 部署 OpenAI API
 
 ```bash
 CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
-    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
-    --template mistral \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
+    --template llama3 \
     --infer_backend vllm \
     --vllm_enforce_eager
 ```
-### 使用魔搭社区
+### 从魔搭社区下载
 
 如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
@@ -395,7 +412,7 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
 export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 ```
 
-将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `modelscope/Llama-2-7b-ms`。
+将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
 
 ## 使用了 LLaMA Factory 的项目
@@ -444,7 +461,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
 
-使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## 引用

View File

@@ -18,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
"history": "the column name in the dataset containing the histories. (default: None)", "history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations)", "messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)" "tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
}, },
"tags (optional, used for the sharegpt format)": { "tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)", "role_tag": "the key in the message represents the identity. (default: from)",

View File

@@ -18,7 +18,8 @@
"history": "数据集代表历史对话的表头名称默认None", "history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations", "messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None", "system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None" "tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
}, },
"tags可选用于 sharegpt 格式)": { "tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from", "role_tag": "消息中代表发送者身份的键名默认from",

View File

@@ -9,6 +9,7 @@ examples/
 │   ├── ppo.sh: Do PPO training using LoRA
 │   ├── dpo.sh: Do DPO training using LoRA
 │   ├── orpo.sh: Do ORPO training using LoRA
+│   ├── sft_mllm.sh: Do supervised fine-tuning on multimodal data using LoRA
 │   ├── prepare.sh: Save tokenized dataset
 │   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
 ├── qlora_single_gpu/
@@ -18,18 +19,19 @@ examples/
 │   └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
 ├── lora_multi_gpu/
 │   ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
-│   └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
+│   ├── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
+│   └── ds_zero3.sh: Fine-tune model with DeepSpeed ZeRO-3 using LoRA (weight sharding)
 ├── full_multi_gpu/
 │   ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
 │   ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
-│   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after full tuning
+│   └── predict.sh: Do parallel batch predict and compute BLEU and ROUGE scores after full tuning
 ├── merge_lora/
 │   ├── merge.sh: Merge LoRA weights into the pre-trained models
 │   └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
 ├── inference/
-│   ├── cli_demo.sh: Launch a command line interface with LoRA adapters
-│   ├── api_demo.sh: Launch an OpenAI-style API with LoRA adapters
-│   ├── web_demo.sh: Launch a web interface with LoRA adapters
+│   ├── cli_demo.sh: Chat with fine-tuned model in the CLI with LoRA adapters
+│   ├── api_demo.sh: Chat with fine-tuned model in an OpenAI-style API with LoRA adapters
+│   ├── web_demo.sh: Chat with fine-tuned model in the Web browser with LoRA adapters
 │   └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
 └── extras/
     ├── galore/

View File

@@ -9,6 +9,7 @@ examples/
 │   ├── ppo.sh: 基于 LoRA 进行 PPO 训练
 │   ├── dpo.sh: 基于 LoRA 进行 DPO 训练
 │   ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
+│   ├── sft_mllm.sh: 基于 LoRA 进行多模态指令监督微调
 │   ├── prepare.sh: 保存预处理后的数据集
 │   └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
 ├── qlora_single_gpu/
@@ -18,11 +19,12 @@ examples/
 │   └── aqlm.sh: 基于 QLoRA 微调 2 比特 AQLM 模型
 ├── lora_multi_gpu/
 │   ├── single_node.sh: 使用 Accelerate 进行单节点 LoRA 训练
-│   └── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
+│   ├── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
+│   └── ds_zero3.sh: 使用 DeepSpeed ZeRO-3 进行 LoRA 训练(拆分权重)
 ├── full_multi_gpu/
 │   ├── single_node.sh: 使用 DeepSpeed 进行单节点全量训练
 │   ├── multi_node.sh: 使用 DeepSpeed 进行多节点全量训练
-│   └── predict.sh: 基于全量训练进行批量预测并计算 BLEU 和 ROUGE 分数
+│   └── predict.sh: 基于全量训练进行多卡批量预测并计算 BLEU 和 ROUGE 分数
 ├── merge_lora/
 │   ├── merge.sh: 将 LoRA 权重合并到预训练模型中
 │   └── quantize.sh: 使用 AutoGPTQ 量化微调后的模型

View File

@@ -9,7 +9,7 @@ main_process_port: 29555
 main_training_function: main
 mixed_precision: fp16
 num_machines: 2 # the number of nodes
-num_processes: 16 # the number of GPUs in all nodes
+num_processes: 8 # the number of GPUs in all nodes
 rdzv_backend: static
 same_network: true
 tpu_env: []

View File

@@ -9,7 +9,7 @@ main_process_port: 29555
 main_training_function: main
 mixed_precision: fp16
 num_machines: 2 # the number of nodes
-num_processes: 16 # the number of GPUs in all nodes
+num_processes: 8 # the number of GPUs in all nodes
 rdzv_backend: static
 same_network: true
 tpu_env: []

View File

@@ -1,4 +1,5 @@
 #!/bin/bash
+# DO NOT use GPTQ/AWQ model in FSDP+QLoRA
 
 pip install "transformers>=4.39.1"
 pip install "accelerate>=0.28.0"

View File

@@ -1,6 +1,8 @@
 #!/bin/bash
 
-CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
+    --config_file ../accelerate/single_config.yaml \
+    ../../src/train_bash.py \
     --stage sft \
     --do_predict \
     --model_name_or_path ../../saves/LLaMA2-7B/full/sft \

View File

@@ -1,4 +1,5 @@
 #!/bin/bash
+# add `--visual_inputs True` to load MLLM
 
 CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
     --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@@ -0,0 +1,33 @@
#!/bin/bash
deepspeed --num_gpus 4 ../../src/train_bash.py \
--deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -1,4 +1,5 @@
 #!/bin/bash
+# also launch it on slave machine using slave_config.yaml
 
 CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
     --config_file ../accelerate/master_config.yaml \

View File

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
+CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
     --config_file ../accelerate/single_config.yaml \
     ../../src/train_bash.py \
     --stage sft \

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--visual_inputs \
--dataset mllm_demo \
--dataset_dir ../../data \
--template vicuna \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft_mllm \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 100.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16
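
For reference, a record consumed by this script pairs alpaca-style text fields with the new `images` column. A hypothetical example record (the actual `mllm_demo` contents may differ):

```python
# Hypothetical record shape for a multimodal SFT dataset in alpaca format; the text
# and the image path are illustrative, not the actual contents of mllm_demo.
record = {
    "instruction": "Describe the image in one sentence.",
    "input": "",
    "output": "A dog is catching a frisbee on the lawn.",
    "images": ["images/example.jpg"],  # per the data README: the image inputs column
}
```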

View File

@@ -8,4 +8,5 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
     --finetuning_type lora \
     --export_dir ../../models/llama2-7b-sft \
     --export_size 2 \
+    --export_device cpu \
     --export_legacy_format False

View File

@@ -1,4 +1,5 @@
 #!/bin/bash
+# NEED TO run `merge.sh` before using this script
 CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
     --model_name_or_path ../../models/llama2-7b-sft \

View File

@@ -15,3 +15,4 @@ fastapi
 sse-starlette
 matplotlib
 fire
+packaging

View File

@@ -44,8 +44,9 @@ def calculate_lr(
             overwrite_cache=True,
         )
     )
-    tokenizer = load_tokenizer(model_args)
-    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":

View File

@@ -32,8 +32,8 @@ def length_cdf(
             overwrite_cache=True,
         )
     )
-    tokenizer = load_tokenizer(model_args)
-    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
+    tokenizer_module = load_tokenizer(model_args)
+    trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     total_num = len(trainset)
     length_dict = defaultdict(int)
     for sample in tqdm(trainset["input_ids"]):
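Both helper scripts now follow the same convention: `load_tokenizer` returns a small dict rather than a bare tokenizer, and `get_dataset` accepts its entries as keyword arguments. A minimal sketch of the calling pattern, using only names visible in the diffs above:

tokenizer_module = load_tokenizer(model_args)   # {"tokenizer": ..., "processor": ...}
tokenizer = tokenizer_module["tokenizer"]       # processor is expected to be None for text-only models
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)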

View File

@@ -22,10 +22,9 @@ def get_requires():
     extra_require = {
         "deepspeed": ["deepspeed>=0.10.0"],
         "metrics": ["nltk", "jieba", "rouge-chinese"],
-        "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
         "galore": ["galore-torch"],
         "badam": ["badam"],
-        "vllm": ["vllm>=0.3.3"],
+        "vllm": ["vllm>=0.4.0"],
         "bitsandbytes": ["bitsandbytes>=0.39.0"],
         "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
         "awq": ["autoawq"],

View File

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
 from .webui import create_ui, create_web_demo
-__version__ = "0.6.3"
+__version__ = "0.7.0"
 __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
     from transformers import PreTrainedModel, PreTrainedTokenizer
     from vllm import AsyncLLMEngine
@@ -46,6 +47,7 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]: ...
@@ -55,6 +57,7 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]: ...

View File

@@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
     from .base_engine import BaseEngine, Response
@@ -36,9 +38,10 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
+        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
         return task.result()
     async def achat(
@@ -46,18 +49,20 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        return await self.engine.chat(messages, system, tools, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, image, **input_kwargs)
     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        generator = self.astream_chat(messages, system, tools, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -70,9 +75,10 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
             yield new_token
     def get_scores(
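A hedged usage sketch of the new `image` argument end to end. The constructor keys mirror the CLI flags in this release, but the dict-based constructor, the local image path, and the `response_text` attribute access are assumptions, not confirmed by this diff:

import numpy as np
from PIL import Image
from llmtuner import ChatModel  # exported in __all__ above

chat_model = ChatModel(dict(
    model_name_or_path="llava-hf/llava-1.5-7b-hf",
    template="vicuna",
    visual_inputs=True,
))
image = np.asarray(Image.open("example.png"))  # hypothetical local file
messages = [{"role": "user", "content": "Describe the image."}]
for response in chat_model.chat(messages, image=image):
    print(response.response_text)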

View File

@@ -14,7 +14,9 @@ from .base_engine import BaseEngine, Response
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from numpy.typing import NDArray
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from transformers.image_processing_utils import BaseImageProcessor
     from trl import PreTrainedModelWrapper
     from ..data import Template
@@ -30,7 +32,9 @@ class HuggingfaceEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         self.can_generate = finetuning_args.stage == "sft"
-        self.tokenizer = load_tokenizer(model_args)
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
         self.model = load_model(
@@ -42,13 +46,18 @@ class HuggingfaceEngine(BaseEngine):
     def _process_args(
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: Dict[str, Any],
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
+        if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
+            messages[0]["content"] = "<image>" + messages[0]["content"]
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
@@ -95,6 +104,11 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
+        if processor is not None and image is not None:
+            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
+            gen_kwargs["pixel_values"] = pixel_values.to(model.device)
         return gen_kwargs, prompt_length
     @staticmethod
@@ -102,15 +116,17 @@ class HuggingfaceEngine(BaseEngine):
     def _chat(
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: Dict[str, Any],
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -135,15 +151,17 @@ class HuggingfaceEngine(BaseEngine):
     def _stream_chat(
         model: "PreTrainedModel",
         tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: Dict[str, Any],
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -199,6 +217,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -208,11 +227,13 @@ class HuggingfaceEngine(BaseEngine):
         input_args = (
             self.model,
             self.tokenizer,
+            self.processor,
             self.template,
             self.generating_args,
             messages,
             system,
             tools,
+            image,
             input_kwargs,
         )
         async with self._semaphore:
@@ -224,6 +245,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -233,11 +255,13 @@ class HuggingfaceEngine(BaseEngine):
         input_args = (
             self.model,
             self.tokenizer,
+            self.processor,
             self.template,
             self.generating_args,
             messages,
             system,
             tools,
+            image,
             input_kwargs,
         )
         async with self._semaphore:
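Worked example of the placeholder logic in `_process_args`: when a processor and an image are both supplied and the first user turn lacks the placeholder, a single `<image>` token is prepended before template encoding, and the processed tensor is attached to `gen_kwargs`:

messages = [{"role": "user", "content": "What is shown here?"}]
# after _process_args:
#   messages[0]["content"] == "<image>What is shown here?"
#   gen_kwargs["pixel_values"].shape == (1, 3, 336, 336) for LLaVA-1.5 (CLIP ViT-L/14-336)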

View File

@@ -2,16 +2,23 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
 from ..data import get_template_and_fix_tokenizer
-from ..extras.misc import get_device_count
+from ..extras.misc import get_device_count, infer_optim_dtype
 from ..extras.packages import is_vllm_available
-from ..model import load_tokenizer
+from ..model import load_config, load_tokenizer
 from .base_engine import BaseEngine, Response
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+    from vllm.lora.request import LoRARequest
+    from vllm.sequence import MultiModalData
 if TYPE_CHECKING:
+    import torch
+    from numpy.typing import NDArray
+    from transformers.image_processing_utils import BaseImageProcessor
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -23,31 +30,59 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        config = load_config(model_args)  # may download model from ms hub
+        infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        infer_dtype = str(infer_dtype).split(".")[-1]
         self.can_generate = finetuning_args.stage == "sft"
-        engine_args = AsyncEngineArgs(
-            model=model_args.model_name_or_path,
-            trust_remote_code=True,
-            max_model_len=model_args.vllm_maxlen,
-            tensor_parallel_size=get_device_count() or 1,
-            gpu_memory_utilization=model_args.vllm_gpu_util,
-            disable_log_stats=True,
-            disable_log_requests=True,
-            enforce_eager=model_args.vllm_enforce_eager,
-        )
-        self.model = AsyncLLMEngine.from_engine_args(engine_args)
-        self.tokenizer = load_tokenizer(model_args)
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
         self.generating_args = generating_args.to_dict()
+        engine_args = {
+            "model": model_args.model_name_or_path,
+            "trust_remote_code": True,
+            "download_dir": model_args.cache_dir,
+            "dtype": infer_dtype,
+            "max_model_len": model_args.vllm_maxlen,
+            "tensor_parallel_size": get_device_count() or 1,
+            "gpu_memory_utilization": model_args.vllm_gpu_util,
+            "disable_log_stats": True,
+            "disable_log_requests": True,
+            "enforce_eager": model_args.vllm_enforce_eager,
+            "enable_lora": model_args.adapter_name_or_path is not None,
+        }
+        if model_args.visual_inputs:
+            # TODO: auto derive from config
+            # https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
+            self.image_feature_size = 576
+            engine_args["image_input_type"] = "pixel_values"
+            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+            engine_args["image_input_shape"] = "1,3,336,336"
+            engine_args["image_feature_size"] = self.image_feature_size
+        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+        else:
+            self.lora_request = None
     async def _generate(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+        if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
+            messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
@@ -91,8 +126,21 @@ class VllmEngine(BaseEngine):
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
         )
+        if self.processor is not None and image is not None:
+            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
+            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+        else:
+            multi_modal_data = None
         result_generator = self.model.generate(
-            prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
+            prompt=None,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            prompt_token_ids=prompt_ids,
+            lora_request=self.lora_request,
+            multi_modal_data=multi_modal_data,
         )
         return result_generator
@@ -104,10 +152,11 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, **input_kwargs)
         async for request_output in generator:
             final_output = request_output
@@ -129,10 +178,11 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
View File

@@ -1,3 +1,4 @@
+import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -13,8 +14,23 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr
-def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
+def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+    outputs = []
+    if dataset_attr.load_from in ["script", "file"]:
+        for image in images:
+            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
+                outputs.append(os.path.join(data_args.dataset_dir, image))
+            else:
+                outputs.append(image)
+    return outputs
+def convert_alpaca(
+    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+) -> Dict[str, List[Any]]:
+    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
     for i in range(len(examples[dataset_attr.prompt])):
         prompt = []
         if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@@ -44,12 +60,16 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
         outputs["tools"].append("")
+        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
     return outputs
-def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
+def convert_sharegpt(
+    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+) -> Dict[str, List[Any]]:
+    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
     tag_mapping = {
         dataset_attr.user_tag: Role.USER.value,
         dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -84,6 +104,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
         outputs["response"].append(aligned_messages[-1:])
         outputs["system"].append(system)
         outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
+        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
     return outputs
@@ -96,12 +117,13 @@ def align_dataset(
         prompt: [{"role": "user", "content": "..."}] * (2T - 1)
         response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
         system: "..."
-        tools: "..."
+        tools: "...",
+        images: [],
     """
     if dataset_attr.formatting == "alpaca":
-        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
+        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
     else:
-        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
+        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
     column_names = list(next(iter(dataset)).keys())
     features = Features.from_dict(
@@ -114,6 +136,7 @@ def align_dataset(
             ],
             "system": {"dtype": "string", "_type": "Value"},
             "tools": {"dtype": "string", "_type": "Value"},
+            "images": [{"_type": "Image"}],
         }
     )
     kwargs = {}
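The effect of `_convert_images` is to rewrite dataset-relative image paths into paths that exist on disk, leaving everything else untouched. A small illustration with hypothetical values:

import os

dataset_dir = "../../data"          # from data_args.dataset_dir
image = "mllm_demo_data/1.jpg"      # hypothetical relative path stored in the dataset
if os.path.isfile(os.path.join(dataset_dir, image)):
    image = os.path.join(dataset_dir, image)  # "../../data/mllm_demo_data/1.jpg"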

View File

@@ -1,6 +1,6 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union
 from datasets import load_dataset, load_from_disk
@@ -16,7 +16,7 @@ from .utils import checksum, merge_dataset
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import Seq2SeqTrainingArguments
+    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
     from ..hparams import DataArguments, ModelArguments
@@ -115,11 +115,12 @@ def load_single_dataset(
 def get_dataset(
-    tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"],
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
     if data_args.train_on_prompt and template.efficient_eos:
@@ -149,7 +150,7 @@ def get_dataset(
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            tokenizer, template, data_args, training_args, stage
+            data_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}

View File

@@ -28,6 +28,7 @@ class DatasetAttr:
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     """ columns """
     system: Optional[str] = None
+    images: Optional[str] = None
     """ columns for the alpaca format """
     prompt: Optional[str] = "instruction"
     query: Optional[str] = "input"
@@ -105,7 +106,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
         if "columns" in dataset_info[name]:
-            column_names = ["system"]
+            column_names = ["system", "images"]
             if dataset_attr.formatting == "alpaca":
                 column_names.extend(["prompt", "query", "response", "history"])
             else:
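With `images` added to the recognized column names, a `dataset_info.json` entry can map the image field alongside the text columns. A hypothetical alpaca-format entry, shown as a Python dict (the file name and column values are illustrative, not taken from the repository):

dataset_info_entry = {
    "my_mllm_dataset": {
        "file_name": "my_mllm_dataset.json",
        "formatting": "alpaca",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "images": "images",  # new in this release
        },
    }
}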

View File

@@ -1,14 +1,22 @@
 from functools import partial
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
+from ..extras.packages import is_pillow_available
 from .utils import Role
+if is_pillow_available():
+    from PIL import Image
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
+    from numpy.typing import NDArray
+    from PIL.Image import Image as ImageObject
+    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
+    from transformers.image_processing_utils import BaseImageProcessor
     from transformers.tokenization_utils import PreTrainedTokenizer
     from ..hparams import DataArguments
@@ -18,6 +26,13 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
+def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+    # process visual inputs (currently only supports a single image)
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+    return image_processor(image, return_tensors="pt")["pixel_values"][0]
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
@@ -48,18 +63,25 @@ def preprocess_pretrain_dataset(
 def preprocess_supervised_dataset(
     examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
     template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             continue
+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
@@ -89,14 +111,16 @@ def preprocess_supervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
     return model_inputs
 def preprocess_packed_supervised_dataset(
     examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
     template: "Template",
+    tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
@@ -141,17 +165,24 @@ def preprocess_packed_supervised_dataset(
 def preprocess_unsupervised_dataset(
     examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
     template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
             continue
+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
         if len(examples["response"][i]) == 1:
             messages = examples["prompt"][i] + examples["response"][i]
         else:
@@ -172,22 +203,32 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
     return model_inputs
 def preprocess_pairwise_dataset(
     examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
     template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
             continue
+        if processor is not None:
+            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
         chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
         prompt_ids, chosen_ids = template.encode_oneturn(
@@ -214,6 +255,8 @@ def preprocess_pairwise_dataset(
         model_inputs["prompt_ids"].append(prompt_ids)
         model_inputs["chosen_ids"].append(chosen_ids)
         model_inputs["rejected_ids"].append(rejected_ids)
+        if processor is not None:
+            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
     return model_inputs
@@ -244,34 +287,54 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
 def get_preprocess_and_print_func(
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
-        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
+        preprocess_func = partial(
+            preprocess_pretrain_dataset,
+            tokenizer=tokenizer,
+            data_args=data_args,
+        )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
         if data_args.packing:
             preprocess_func = partial(
-                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+                preprocess_packed_supervised_dataset,
+                template=template,
+                tokenizer=tokenizer,
+                data_args=data_args,
             )
         else:
             preprocess_func = partial(
-                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+                preprocess_supervised_dataset,
+                template=template,
+                tokenizer=tokenizer,
+                processor=processor,
+                data_args=data_args,
            )
         print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
     elif stage == "rm":
         preprocess_func = partial(
-            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+            preprocess_pairwise_dataset,
+            template=template,
+            tokenizer=tokenizer,
+            processor=processor,
+            data_args=data_args,
        )
         print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
     else:
         preprocess_func = partial(
-            preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+            preprocess_unsupervised_dataset,
+            template=template,
+            tokenizer=tokenizer,
+            processor=processor,
+            data_args=data_args,
        )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
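A hedged sketch of what `_preprocess_visual_inputs` produces for LLaVA-1.5 (the model id appears in this release; the 336x336 output shape follows its CLIP image processor):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.new("RGB", (100, 100), (255, 255, 255))  # white fallback, as in the code above
pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"][0]
print(pixel_values.shape)  # expected: torch.Size([3, 336, 336])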

View File

@@ -503,6 +503,7 @@ _register_template(
name="chatml", name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"], stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True, replace_eos=True,
@@ -513,6 +514,7 @@ _register_template(
name="chatml_de", name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"], stop_words=["<|im_end|>", "<|im_start|>"],
@@ -550,6 +552,32 @@ _register_template(
) )
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)\nThis is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template( _register_template(
name="deepseek", name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -608,6 +636,9 @@ _register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]), format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
efficient_eos=True, efficient_eos=True,
force_system=True, force_system=True,
@@ -678,6 +709,14 @@ _register_template(
format_system=StringFormatter( format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
), ),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"], stop_words=["<|eot_id|>"],
replace_eos=True, replace_eos=True,
@@ -718,10 +757,23 @@ _register_template(
) )
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template( _register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
@@ -818,7 +870,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate", default_system="You are Zephyr, a helpful assistant.",
) )
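To make the new format_observation slots concrete: under the chatml-style templates, a tool result is rendered into the prompt like this before the assistant's next turn (a sketch assuming plain {{content}} substitution):

tool_output = '{"temperature": 22, "unit": "C"}'  # hypothetical tool result
rendered = "<|im_start|>tool\n" + tool_output + "<|im_end|>\n<|im_start|>assistant\n"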

View File

@@ -78,9 +78,9 @@ def split_dataset(
     if training_args.do_train:
         if data_args.val_size > 1e-6:  # Split the dataset
             if data_args.streaming:
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
                 val_set = dataset.take(int(data_args.val_size))
                 train_set = dataset.skip(int(data_args.val_size))
-                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
                 return {"train_dataset": train_set, "eval_dataset": val_set}
             else:
                 val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
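The reordering matters for streaming datasets: take() and skip() operate lazily on the stream that produced them, so the old code's shuffle, applied after the split and assigned to a variable that was never used again, had no effect on either split. The corrected pattern (buffer size, seed, and file name are illustrative):

from datasets import load_dataset

stream = load_dataset("json", data_files="data.json", split="train", streaming=True)
stream = stream.shuffle(buffer_size=16384, seed=42)  # shuffle BEFORE splitting
val_set = stream.take(100)    # first 100 shuffled examples
train_set = stream.skip(100)  # everything after them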

View File

@@ -21,7 +21,7 @@ from .template import get_eval_template
 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
-        self.tokenizer = load_tokenizer(self.model_args)
+        self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)

View File

@@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
MLLM_LIST = ["LLaVA1.5"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"] MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"] PEFT_METHODS = ["lora"]
@@ -47,6 +49,8 @@ TRAINING_STAGES = {
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -266,6 +270,22 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "DBRX-132B-Base": {
+            DownloadSource.DEFAULT: "databricks/dbrx-base",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
+        },
+        "DBRX-132B-Chat": {
+            DownloadSource.DEFAULT: "databricks/dbrx-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
+        },
+    },
+    module="Wqkv",
+    template="dbrx",
+)
 register_model_group(
     models={
         "DeepSeek-LLM-7B-Base": {
@@ -286,9 +306,11 @@ register_model_group(
         },
         "DeepSeek-Math-7B-Base": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
+            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
         },
         "DeepSeek-Math-7B-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
+            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
         },
         "DeepSeek-MoE-16B-Base": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
@@ -451,6 +473,16 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "Jamba-v0.1": {
+            DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
+        }
+    },
+)
 register_model_group(
     models={
         "LingoWhale-8B": {
@@ -538,6 +570,19 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "LLaVA1.5-7B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
+        },
+        "LLaVA1.5-13B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
+        },
+    },
+    template="vicuna",
+)
 register_model_group(
     models={
         "Mistral-7B-v0.1": {
@@ -573,6 +618,7 @@ register_model_group(
         },
         "Mixtral-8x22B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
         },
         "Mixtral-8x22B-v0.1-Chat": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
@@ -585,18 +631,15 @@ register_model_group(
 register_model_group(
     models={
         "OLMo-1B": {
-            DownloadSource.DEFAULT: "allenai/OLMo-1B",
+            DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
         },
         "OLMo-7B": {
-            DownloadSource.DEFAULT: "allenai/OLMo-7B",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
+            DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
         },
-        "OLMo-7B-Chat": {
-            DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
+        "OLMo-1.7-7B": {
+            DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
         },
     },
-    module="att_proj",
-    template="olmo",
 )
@@ -604,7 +647,7 @@ register_model_group(
     models={
         "OpenChat3.5-7B-Chat": {
             DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
-            DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
+            DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
         }
     },
     template="openchat",
@@ -652,6 +695,22 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "Phi3-3.8B-4k-Chat": {
+            DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
+        },
+        "Phi3-3.8B-128k-Chat": {
+            DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
+        },
+    },
+    module="qkv_proj",
+    template="phi",
+)
 register_model_group(
     models={
         "Qwen-1.8B": {
@@ -754,6 +813,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
         },
+        "Qwen1.5-110B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
+        },
         "Qwen1.5-MoE-A2.7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
@@ -790,6 +853,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
         },
+        "Qwen1.5-110B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
+        },
         "Qwen1.5-MoE-A2.7B-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
@@ -850,6 +917,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
         },
+        "Qwen1.5-110B-int4-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
+        },
         "Qwen1.5-MoE-A2.7B-int4-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
@@ -891,12 +962,15 @@ register_model_group(
     models={
         "StarCoder2-3B": {
             DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
         },
         "StarCoder2-7B": {
             DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
         },
         "StarCoder2-15B": {
             DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
         },
     }
 )
@@ -919,17 +993,53 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"XuanYuan-6B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
},
"XuanYuan-70B": { "XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan-2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
"XuanYuan-6B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
},
"XuanYuan-70B-Chat": { "XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan-2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan-2-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan-2-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
},
template="xuanyuan",
@@ -966,6 +1076,30 @@ register_model_group(
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat", DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
}, },
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
@@ -1058,6 +1192,9 @@ register_model_group(
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta", DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
}, },
"Zephyr-141B-ORPO-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
},
},
template="zephyr",
)
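
For orientation, every hunk in this file follows the same registration pattern: a mapping from a display name to per-source repository IDs, plus an optional chat template and an optional LoRA target override (as in Phi3's module="qkv_proj" above). A minimal sketch of adding another entry; the model name, organization, and template name here are hypothetical:

register_model_group(
    models={
        "MyModel-7B-Chat": {  # hypothetical display name
            DownloadSource.DEFAULT: "my-org/MyModel-7B-Chat",  # Hugging Face Hub repo
            DownloadSource.MODELSCOPE: "my-org/MyModel-7B-Chat",  # ModelScope mirror
        },
    },
    template="mymodel",  # must name a registered chat template
)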

View File

@@ -1,16 +1,23 @@
import importlib.metadata
import importlib.util
from typing import TYPE_CHECKING
from packaging import version
if TYPE_CHECKING:
from packaging.version import Version
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
-def _get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> "Version":
try:
-return importlib.metadata.version(name)
+return version.parse(importlib.metadata.version(name))
except Exception:
-return "0.0.0"
+return version.parse("0.0.0")
def is_fastapi_availble():
@@ -18,7 +25,7 @@ def is_fastapi_availble():
def is_flash_attn2_available():
-return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
+return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available():
@@ -41,6 +48,10 @@ def is_nltk_available():
return _is_package_available("nltk") return _is_package_available("nltk")
def is_pillow_available():
return _is_package_available("PIL")
def is_requests_available():
return _is_package_available("requests")
@@ -49,6 +60,10 @@ def is_rouge_available():
return _is_package_available("rouge_chinese") return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available():
return _is_package_available("sse_starlette")

View File

@@ -26,11 +26,11 @@ class DataArguments:
)
cutoff_len: int = field(
default=1024,
-metadata={"help": "The cutoff length of the model inputs after tokenization."},
+metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
reserved_label_len: int = field(
default=1,
-metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
+metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
)
train_on_prompt: bool = field(
default=False,

View File

@@ -31,11 +31,11 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."}, metadata={"help": "Number of beams for beam search. 1 means no beam search."},
) )
max_length: int = field( max_length: int = field(
default=512, default=1024,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
) )
max_new_tokens: int = field( max_new_tokens: int = field(
default=512, default=1024,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
) )
repetition_penalty: float = field( repetition_penalty: float = field(

View File

@@ -22,7 +22,7 @@ class ModelArguments:
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
) )
use_fast_tokenizer: bool = field( use_fast_tokenizer: bool = field(
default=False, default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
) )
resize_vocab: bool = field( resize_vocab: bool = field(
@@ -33,6 +33,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -61,9 +65,9 @@ class ModelArguments:
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
-flash_attn: bool = field(
-default=False,
-metadata={"help": "Enable FlashAttention for faster training."},
+flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+default="auto",
+metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
@@ -77,6 +81,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
@@ -135,7 +143,7 @@ class ModelArguments:
)
export_device: str = field(
default="cpu",
-metadata={"help": "The device used in model export."},
+metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
)
export_quantization_bit: Optional[int] = field(
default=None,
@@ -174,9 +182,15 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
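
Because __post_init__ splits the comma-separated strings, the new options can be passed as plain strings; a minimal sketch, assuming the llmtuner package layout and illustrative token names:

from llmtuner.hparams import ModelArguments

args = ModelArguments(
    model_name_or_path="microsoft/Phi-3-mini-4k-instruct",
    new_special_tokens="<tool>,</tool>",  # hypothetical special tokens
)
assert args.new_special_tokens == ["<tool>", "</tool>"]  # split in __post_init__
assert args.flash_attn == "auto"  # new default attention backend selector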

View File

@@ -67,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.") raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
@@ -86,7 +89,7 @@ def _check_extra_dependencies(
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3") require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")
if finetuning_args.use_galore: if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch") require_version("galore_torch", "To fix: pip install galore_torch")
@@ -94,6 +97,9 @@ def _check_extra_dependencies(
if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
@@ -190,16 +196,20 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
-and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
-logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
+logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -301,15 +311,18 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if finetuning_args.stage != "sft": if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.") raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.adapter_name_or_path is not None:
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support quantization.") raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None: if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.") raise ValueError("vLLM engine does not support RoPE scaling.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args)
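
These checks lean on transformers' require_version helper, which is a no-op when the pin is satisfied and otherwise raises with the hint appended; a standalone sketch:

from transformers.utils.versions import require_version

# Passes silently if the installed vllm satisfies the pin, otherwise raises
# an ImportError carrying the "To fix" message:
require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")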

View File

@@ -1,8 +1,10 @@
-from .loader import load_model, load_tokenizer
+from .loader import load_config, load_model, load_tokenizer
-from .utils import find_all_linear_modules, load_valuehead_params
+from .utils.misc import find_all_linear_modules
+from .utils.valuehead import load_valuehead_params
__all__ = [
+"load_config",
"load_model",
"load_tokenizer",
"load_valuehead_params",

View File

@@ -5,11 +5,13 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras.logging import get_logger
-from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
+from .utils.misc import find_all_linear_modules, find_expanded_modules
+from .utils.quantization import QuantizationMethod
+from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:
-from transformers.modeling_utils import PreTrainedModel
+from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
@@ -18,7 +20,11 @@ logger = get_logger(__name__)
def init_adapter(
-model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
+config: "PretrainedConfig",
+model: "PreTrainedModel",
+model_args: "ModelArguments",
+finetuning_args: "FinetuningArguments",
+is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -105,6 +111,10 @@ def init_adapter(
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
@@ -121,9 +131,15 @@ def init_adapter(
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained( if model_args.use_unsloth:
model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
) else:
model = PeftModel.from_pretrained(
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
if is_trainable and adapter_to_resume is None: # create new lora weights while training if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -141,6 +157,17 @@ def init_adapter(
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
@@ -151,14 +178,7 @@ def init_adapter(
}
if model_args.use_unsloth:
-from unsloth import FastLanguageModel # type: ignore
-unsloth_peft_kwargs = {
-"model": model,
-"max_seq_length": model_args.model_max_length,
-"use_gradient_checkpointing": "unsloth",
-}
-model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
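
The merge-then-resume flow above (adapter_to_merge vs. adapter_to_resume) can be pictured with plain peft calls; a minimal sketch with hypothetical adapter paths, not repo code:

from peft import PeftModel

# `model` is assumed to be a previously loaded base PreTrainedModel
adapters = ["saves/lora-stage1", "saves/lora-stage2"]  # hypothetical checkpoints
for adapter in adapters[:-1]:
    model = PeftModel.from_pretrained(model, adapter)
    model = model.merge_and_unload()  # bake the adapter into the base weights

# the last adapter stays trainable so training can resume from it
model = PeftModel.from_pretrained(model, adapters[-1], is_trainable=True)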

View File

@@ -1,18 +1,20 @@
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
-from ..extras.constants import MOD_SUPPORTED_MODELS
from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
-from .utils import load_valuehead_params, register_autoclass
+from .utils.misc import register_autoclass
+from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
+from .utils.unsloth import load_unsloth_pretrained_model
+from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING:
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
@@ -20,7 +22,17 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
@@ -30,9 +42,9 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
}
-def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
+def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
-Loads pretrained tokenizer. Must before load_model.
+Loads pretrained tokenizer.
Note: including inplace operation of model_args.
"""
@@ -53,8 +65,33 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
-return tokenizer
if model_args.visual_inputs:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
else:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
@@ -65,61 +102,39 @@ def load_model(
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
-Loads pretrained model. Must after load_tokenizer.
+Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
-config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+config = load_config(model_args)
-patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
model = None
-if is_trainable and model_args.use_unsloth:
-from unsloth import FastLanguageModel # type: ignore
-unsloth_kwargs = {
-"model_name": model_args.model_name_or_path,
-"max_seq_length": model_args.model_max_length,
-"dtype": model_args.compute_dtype,
-"load_in_4bit": model_args.quantization_bit == 4,
-"token": model_args.hf_hub_token,
-"device_map": {"": get_current_device()},
-"rope_scaling": getattr(config, "rope_scaling", None),
-"fix_tokenizer": False,
-"trust_remote_code": True,
-}
-try:
-model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
-except NotImplementedError:
-logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
-model_args.use_unsloth = False
-if model_args.adapter_name_or_path:
-model_args.adapter_name_or_path = None
-logger.warning("Unsloth does not support loading adapters.")
-if model is None:
+lazy_load = False
+if model_args.use_unsloth:
+if model_args.adapter_name_or_path is not None:
+lazy_load = True
+elif is_trainable:
+model = load_unsloth_pretrained_model(config, model_args)
+if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
-from MoD import AutoMoDModelForCausalLM
-model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
+model = load_mod_pretrained_model(**init_kwargs)
+elif model_args.visual_inputs:
+model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
-from MoD import apply_mod_to_hf
-if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
-raise ValueError("Current model is not supported by mixture-of-depth.")
-model = apply_mod_to_hf(model)
-model = model.to(model_args.compute_dtype)
+model = convert_pretrained_model_to_mod(model, config, model_args)
-patch_model(model, tokenizer, model_args, is_trainable)
-register_autoclass(config, model, tokenizer)
+if not lazy_load:
+patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
+register_autoclass(config, model, tokenizer)
-model = init_adapter(model, model_args, finetuning_args, is_trainable)
+model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
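
Call sites change accordingly, since load_tokenizer now returns a TokenizerModule dict instead of a bare tokenizer; a minimal consumption sketch (the argument objects are assumed to be prepared elsewhere):

tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]   # always present
processor = tokenizer_module["processor"]   # None unless visual_inputs is enabled
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=True)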

View File

@@ -1,23 +1,22 @@
-import math
-import os
-import random
-from contextlib import nullcontext
from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict
import torch
-from datasets import load_dataset
from peft import PeftModel
-from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
-from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
-from ..extras.misc import get_current_device, infer_optim_dtype
+from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_flash_attn2_available
-from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
+from .utils.attention import configure_attn_implementation, print_attn_implementation
+from .utils.checkpointing import prepare_model_for_training
+from .utils.embedding import resize_embedding_layer
+from .utils.longlora import configure_longlora
+from .utils.moe import add_z3_leaf_module, configure_moe
+from .utils.quantization import configure_quantization
+from .utils.rope import configure_rope
+from .utils.valuehead import configure_valuehead, prepare_valuehead_model
+from .utils.visual import autocast_projector_dtype
if TYPE_CHECKING:
@@ -28,255 +27,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
return
logger.info("Using FlashAttention-2 for faster training and inference.")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", "flash_attention_2")
else:
setattr(config, "_attn_implementation", "flash_attention_2")
else:
setattr(config, "_attn_implementation", "eager")
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def _prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(_fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -290,25 +40,24 @@ def patch_config(
model_args: "ModelArguments", model_args: "ModelArguments",
init_kwargs: Dict[str, Any], init_kwargs: Dict[str, Any],
is_trainable: bool, is_trainable: bool,
add_valuehead: bool,
) -> None: ) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
_configure_attn_implementation(config, model_args) configure_attn_implementation(config, model_args)
_configure_rope(config, model_args, is_trainable) configure_rope(config, model_args, is_trainable)
_configure_longlora(config, model_args, is_trainable) configure_longlora(config, model_args, is_trainable)
_configure_quantization(config, tokenizer, model_args, init_kwargs) configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
if add_valuehead:
configure_valuehead(config)
if model_args.use_cache and not is_trainable: if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True) setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.") logger.info("Using KV cache for faster generation.")
-if model_args.moe_aux_loss_coef is not None:
-if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
-setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
-elif getattr(config, "model_type", None) == "deepseek":
-setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
@@ -317,9 +66,6 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn: if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
setattr(config, "output_router_logits", True)
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
@@ -332,7 +78,11 @@ def patch_config(
def patch_model(
-model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
+model: "PreTrainedModel",
+tokenizer: "PreTrainedTokenizer",
+model_args: "ModelArguments",
+is_trainable: bool,
+add_valuehead: bool,
) -> None:
gen_config = model.generation_config # check and fix generation config
if not gen_config.do_sample and (
@@ -345,25 +95,21 @@ def patch_model(
if "GenerationMixin" not in str(model.generate.__func__): if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model) model.generate = MethodType(PreTrainedModel.generate, model)
if is_trainable and getattr(model.config, "model_type", None) == "chatglm": if add_valuehead:
setattr(model, "lm_head", model.transformer.output_layer) prepare_valuehead_model(model)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if model_args.resize_vocab: if model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer) resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable: if is_trainable:
_prepare_model_for_training(model, model_args) prepare_model_for_training(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) == "mixtral": if not model_args.use_unsloth:
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock print_attn_implementation(model.config)
add_z3_leaf_module(model, MixtralSparseMoeBlock)
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
try:
model.add_model_tags(["llama-factory"])

View File

@@ -1,175 +0,0 @@
import inspect
from enum import Enum, unique
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import cached_file
from transformers.utils.versions import require_version
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ..hparams import ModelArguments
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, [module])
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r"""
Finds all available modules to apply lora or galore.
"""
quantization_method = getattr(model, "quantization_method", None)
if quantization_method is None:
linear_cls = torch.nn.Linear
elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
elif model.config.model_type == "internlm2":
output_layer_names.append("output")
module_names = set()
for name, module in model.named_modules():
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
trainable_layer in name for trainable_layer in trainable_layers
):
module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names
def gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

View File

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "off":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
logger.warning("Torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla Attention implementation.")

View File

@@ -0,0 +1,94 @@
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments

logger = get_logger(__name__)

def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for the block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__
        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)

def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    return output.to(torch.float32)

def prepare_model_for_training(
    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
    r"""
    Includes:
    (1) cast the layernorm in fp32
    (2) make output embedding layer require grads
    (3) add the upcasting of the lm_head in fp32

    Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning("Current model does not support gradient checkpointing.")
        else:
            # use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")

    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
        logger.info("Upcasting lm_head outputs in float32.")
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
            output_layer.register_forward_hook(_fp32_forward_post_hook)
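As a minimal sketch of why the custom checkpointing function flips requires_grad on floating-point inputs, consider a toy module; everything below is illustrative and not part of the file above.

# Sketch: frozen inputs break reentrant checkpointing unless some tensor
# in the call requires grad; mirroring the patch, we enable it on the fly.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)

def run(x):
    if any(p.requires_grad for p in layer.parameters()):
        if torch.is_tensor(x) and torch.is_floating_point(x):
            x.requires_grad_(True)
    return checkpoint(layer, x, use_reentrant=True)

out = run(torch.randn(2, 4))
out.sum().backward()  # gradients flow back through the checkpointed layer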

View File

@@ -0,0 +1,58 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

logger = get_logger(__name__)

def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
    noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight

def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Resize token embeddings.
    """
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [model.get_input_embeddings().weight]
        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
            params.append(model.get_output_embeddings().weight)

        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    with context_maybe_zero3:
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) > current_embedding_size:
        if getattr(model, "quantization_method", None):
            raise ValueError("Cannot resize embedding layers of a quantized model.")

        if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
            raise ValueError("Current model does not support resizing embedding layers.")

        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
        with context_maybe_zero3:
            new_embedding_size = model.get_input_embeddings().weight.size(0)
            num_new_tokens = new_embedding_size - current_embedding_size
            _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
            _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

        logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))

View File

@@ -1,5 +1,5 @@
import math
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
@@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import (
    Cache,
    LlamaAttention,
    LlamaFlashAttention2,
+    LlamaSdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version

+from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
logger = logging.get_logger(__name__)

# Modified from:
-# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
-def llama_torch_attn_forward(
+def llama_attention_forward(
    self: "LlamaAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
@@ -39,10 +48,11 @@ def llama_torch_attn_forward(
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+    past_key_value = getattr(self, "past_key_value", past_key_value)
    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-    past_key_value = getattr(self, "past_key_value", past_key_value)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -69,8 +79,9 @@ def llama_torch_attn_forward(
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -97,8 +108,8 @@ def llama_torch_attn_forward(
# Modified from:
-# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
-def llama_flash_attn_forward(
+def llama_flash_attention_2_forward(
    self: "LlamaFlashAttention2",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
@@ -117,7 +128,6 @@ def llama_flash_attn_forward(
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

-    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -134,9 +144,10 @@ def llama_flash_attn_forward(
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

-    query_states = query_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-    key_states = key_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-    value_states = value_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
+    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0
@@ -192,7 +203,115 @@ def llama_flash_attn_forward(
    return attn_output, attn_weights, past_key_value

-def apply_llama_patch() -> None:
-    require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3")
-    LlamaAttention.forward = llama_torch_attn_forward
-    LlamaFlashAttention2.forward = llama_flash_attn_forward
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_sdpa_attention_forward(
+    self: "LlamaSdpaAttention",
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional["Cache"] = None,
+    output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
+        return llama_attention_forward(
+            self,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
+        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        num_groups = q_len // groupsz
+
+        def shift(state: torch.Tensor) -> torch.Tensor:
+            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
+            state = torch.cat(
+                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+                dim=2,
+            )
+            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+
+    causal_mask = attention_mask
+    if attention_mask is not None:
+        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+    query_states = query_states.contiguous()
+    key_states = key_states.contiguous()
+    value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=causal_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        is_causal=causal_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+        attn_output = torch.cat(
+            (
+                attn_output[:, :, : self.num_heads // 2],
+                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+            ),
+            dim=2,
+        )
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+    return attn_output, None, past_key_value
+
+def _apply_llama_patch() -> None:
+    require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
+    LlamaAttention.forward = llama_attention_forward
+    LlamaFlashAttention2.forward = llama_flash_attention_2_forward
+    LlamaSdpaAttention.forward = llama_sdpa_attention_forward
+
+def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.shift_attn:
+        return
+
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+        setattr(config, "group_size_ratio", 0.25)
+        _apply_llama_patch()
+        logger.info("Using shift short attention with group_size_ratio=1/4.")
+    else:
+        logger.warning("Current model does not support shift short attention.")

View File

@@ -0,0 +1,78 @@
from typing import TYPE_CHECKING, List

import torch

from ...extras.logging import get_logger
from .quantization import QuantizationMethod

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer

logger = get_logger(__name__)

def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
    r"""
    Finds all available modules to apply lora or galore.
    """
    quantization_method = getattr(model, "quantization_method", None)
    if quantization_method is None:
        linear_cls = torch.nn.Linear
    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
        import bitsandbytes as bnb

        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
    else:
        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))

    output_layer_names = ["lm_head"]
    if model.config.model_type == "chatglm":
        output_layer_names.append("output_layer")
    elif model.config.model_type == "internlm2":
        output_layer_names.append("output")

    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
            module_names.add(name.split(".")[-1])

    logger.info("Found linear modules: {}".format(",".join(module_names)))
    return list(module_names)

def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
    r"""
    Finds the modules in the expanded blocks to apply lora.
    """
    num_layers = getattr(model.config, "num_hidden_layers", None)
    if not num_layers:
        raise ValueError("Model was not supported.")

    if num_layers % num_layer_trainable != 0:
        raise ValueError(
            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
        )

    stride = num_layers // num_layer_trainable
    trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
    module_names = []
    for name, _ in model.named_modules():
        if any(target_module in name for target_module in target_modules) and any(
            trainable_layer in name for trainable_layer in trainable_layers
        ):
            module_names.append(name)

    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
    return module_names

def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
    if "AutoConfig" in getattr(config, "auto_map", {}):
        config.__class__.register_for_auto_class()
    if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
        model.__class__.register_for_auto_class()
    if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
        tokenizer.__class__.register_for_auto_class()
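A toy illustration of the name-splitting trick find_all_linear_modules relies on; the module names below are made up for the example.

# Sketch: collect the last path component of every nn.Linear, skipping the
# output layer -- this yields candidate LoRA target names.
import torch

model = torch.nn.ModuleDict({
    "q_proj": torch.nn.Linear(8, 8),
    "v_proj": torch.nn.Linear(8, 8),
    "lm_head": torch.nn.Linear(8, 100),
})
names = {
    name.split(".")[-1]
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear) and "lm_head" not in name
}
print(names)  # {'q_proj', 'v_proj'} -- lm_head is excluded as an output layer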

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

from ...extras.constants import MOD_SUPPORTED_MODELS

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments

def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    from MoD import AutoMoDModelForCausalLM

    return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)

def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    from MoD import apply_mod_to_hf

    if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    model = apply_mod_to_hf(model)
    model = model.to(model_args.compute_dtype)
    return model
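A hedged usage sketch of the same conversion, assuming the MoD package is installed and the checkpoint architecture is supported; the model id is only an example.

# Sketch: wrap a causal LM's decoder layers with mixture-of-depths routing,
# mirroring convert_pretrained_model_to_mod above.
from transformers import AutoModelForCausalLM
from MoD import apply_mod_to_hf

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # example checkpoint
model = apply_mod_to_hf(model)  # adds per-token depth routing to the decoder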

View File

@@ -0,0 +1,53 @@
from typing import TYPE_CHECKING

from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments

def add_z3_leaf_module(model: "PreTrainedModel") -> None:
    r"""
    Sets module as a leaf module to skip partitioning in deepspeed zero3.
    """
    if not is_deepspeed_zero3_enabled():
        return

    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    if getattr(model.config, "model_type", None) == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "qwen2moe":
        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "jamba":
        from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

        set_z3_leaf_modules(model, [JambaSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "dbrx":
        from transformers.models.dbrx.modeling_dbrx import DbrxFFN

        set_z3_leaf_modules(model, [DbrxFFN])

def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if model_args.moe_aux_loss_coef is not None:
        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
        elif getattr(config, "model_type", None) == "deepseek":
            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

    if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
        setattr(config, "output_router_logits", is_trainable)

View File

@@ -0,0 +1,146 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version

from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer

    from ...hparams import ModelArguments

logger = get_logger(__name__)

@unique
class QuantizationMethod(str, Enum):
    r"""
    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
    """

    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    QUANTO = "quanto"

def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
    r"""
    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
    """
    if os.path.isfile(model_args.export_quantization_dataset):
        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
        data_files = model_args.export_quantization_dataset
    else:
        data_path = model_args.export_quantization_dataset
        data_files = None

    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
    maxlen = model_args.export_quantization_maxlen
    samples = []
    for _ in range(model_args.export_quantization_nsamples):
        while True:
            sample_idx = random.randint(0, len(dataset) - 1)
            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
            if sample["input_ids"].size(1) >= maxlen:
                break  # TODO: fix large maxlen

        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))

    return samples

def configure_quantization(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
) -> None:
    r"""
    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
    """
    if getattr(config, "quantization_config", None):  # ptq
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

        if model_args.quantization_device_map != "auto":
            init_kwargs["device_map"] = {"": get_current_device()}

        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")

        if quant_method == QuantizationMethod.GPTQ:
            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
            quantization_config.pop("disable_exllama", None)  # remove deprecated args
            quantization_config["use_exllama"] = False  # disable exllama

        if quant_method == QuantizationMethod.AWQ:
            require_version("autoawq", "To fix: pip install autoawq")

        if quant_method == QuantizationMethod.AQLM:
            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
            quantization_config["bits"] = 2

        quant_bits = quantization_config.get("bits", "?")
        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

    elif model_args.export_quantization_bit is not None:  # auto-gptq
        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
        from accelerate.utils import get_max_memory

        if getattr(config, "model_type", None) == "chatglm":
            raise ValueError("ChatGLM model is not supported.")

        init_kwargs["quantization_config"] = GPTQConfig(
            bits=model_args.export_quantization_bit,
            tokenizer=tokenizer,
            dataset=_get_quantization_dataset(tokenizer, model_args),
        )
        init_kwargs["device_map"] = "auto"
        init_kwargs["max_memory"] = get_max_memory()
        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))

    elif model_args.quantization_bit is not None:  # bnb
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type,
                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
            )

        if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
            if model_args.quantization_bit != 4:
                raise ValueError("Only 4-bit quantized model can use auto device map.")

            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
        else:
            init_kwargs["device_map"] = {"": get_current_device()}

        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

View File

@@ -0,0 +1,47 @@
import math
from typing import TYPE_CHECKING

from ...extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments

logger = get_logger(__name__)

def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if model_args.rope_scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning("Current model does not support RoPE scaling.")
        return

    if is_trainable:
        if model_args.rope_scaling == "dynamic":
            logger.warning(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        current_max_length = getattr(config, "max_position_embeddings", None)
        if current_max_length and model_args.model_max_length > current_max_length:
            logger.info(
                "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
            )
            setattr(config, "max_position_embeddings", model_args.model_max_length)
            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
        else:
            logger.warning("Input length is smaller than max length. Consider increasing the input length.")
            scaling_factor = 1.0
    else:
        scaling_factor = 2.0

    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
    logger.info(
        "Using {} scaling strategy and setting scaling factor to {}.".format(model_args.rope_scaling, scaling_factor)
    )
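A worked example of the scaling factor chosen above, assuming a model pretrained at 4096 positions being extended to 16384; the lengths are illustrative.

# Sketch: linear RoPE scaling factor = ceil(target_len / pretrained_len).
import math

current_max_length = 4096
model_max_length = 16384
scaling_factor = float(math.ceil(model_max_length / current_max_length))
print({"type": "linear", "factor": scaling_factor})  # {'type': 'linear', 'factor': 4.0}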

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Dict, Optional

from ...extras.logging import get_logger
from ...extras.misc import get_current_device

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments

logger = get_logger(__name__)

def _get_unsloth_kwargs(
    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
    return {
        "model_name": model_name_or_path,
        "max_seq_length": model_args.model_max_length or 4096,
        "dtype": model_args.compute_dtype,
        "load_in_4bit": model_args.quantization_bit == 4,
        "token": model_args.hf_hub_token,
        "device_map": {"": get_current_device()},
        "rope_scaling": getattr(config, "rope_scaling", None),
        "fix_tokenizer": False,
        "trust_remote_code": True,
        "use_gradient_checkpointing": "unsloth",
    }

def load_unsloth_pretrained_model(
    config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
    r"""
    Optionally loads pretrained model with unsloth. Used in training.
    """
    from unsloth import FastLanguageModel

    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
    try:
        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError:
        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
        model = None
        model_args.use_unsloth = False

    return model

def get_unsloth_peft_model(
    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
    r"""
    Gets the peft model for the pretrained model with unsloth. Used in training.
    """
    from unsloth import FastLanguageModel

    unsloth_peft_kwargs = {
        "model": model,
        "max_seq_length": model_args.model_max_length,
        "use_gradient_checkpointing": "unsloth",
    }
    return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)

def load_unsloth_peft_model(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
    r"""
    Loads peft model with unsloth. Used in both training and inference.
    """
    from unsloth import FastLanguageModel

    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
    try:
        if not is_trainable:
            unsloth_kwargs["use_gradient_checkpointing"] = False

        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError:
        raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))

    if not is_trainable:
        FastLanguageModel.for_inference(model)

    return model
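A usage sketch of the unsloth entry point wrapped by these helpers, assuming the unsloth package is installed; the checkpoint name and values are illustrative.

# Sketch: load a 4-bit model with unsloth and switch it to inference mode,
# the same call pattern load_unsloth_peft_model uses.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-2-7b-hf",  # example checkpoint
    max_seq_length=4096,
    dtype=None,          # let unsloth pick the compute dtype
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # disable training-only code paths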

View File

@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING, Dict

import torch
from transformers.utils import cached_file

from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments

logger = get_logger(__name__)

def configure_valuehead(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "llava":
        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))

def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.

    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
    """
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}

    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:
        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception as err:
        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))

    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
    return None

def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    if getattr(model.config, "model_type", None) == "llava":
        setattr(model, "lm_head", model.language_model.get_output_embeddings())
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "chatglm":
        setattr(model, "lm_head", model.transformer.output_layer)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Tuple

import torch

from ...extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments

logger = get_logger(__name__)

def autocast_projector_dtype(
    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
    def _mm_projector_forward_post_hook(
        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
    ) -> "torch.Tensor":
        return output.to(model_args.compute_dtype)

    if hasattr(model, mm_projector_name):
        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
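The same dtype-casting hook pattern on a toy module, to show what autocast_projector_dtype registers; bfloat16 stands in for model_args.compute_dtype.

# Sketch: a forward hook that returns a modified output replaces the
# module's output, so downstream layers see the re-cast dtype.
import torch

proj = torch.nn.Linear(4, 4).to(torch.float32)
proj.register_forward_hook(lambda module, args, output: output.to(torch.bfloat16))
print(proj(torch.randn(1, 4)).dtype)  # torch.bfloat16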

View File

@@ -24,8 +24,9 @@ def run_dpo(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding( data_collator = PairwiseDataCollatorWithPadding(

View File

@@ -24,8 +24,9 @@ def run_orpo(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
tokenizer = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding( data_collator = PairwiseDataCollatorWithPadding(

View File

@@ -27,8 +27,9 @@ def run_ppo(
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training

View File

@@ -25,8 +25,9 @@ def run_pt(
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

View File

@@ -25,8 +25,9 @@ def run_rm(
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

View File

@@ -28,8 +28,9 @@ def run_sft(
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
-    tokenizer = load_tokenizer(model_args)
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if training_args.predict_with_generate:
@@ -47,6 +48,7 @@ def run_sft(
    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns

    # Initialize our Trainer
    trainer = CustomSeq2SeqTrainer(

View File

@@ -52,7 +52,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

-    tokenizer = load_tokenizer(model_args)
+    tokenizer = load_tokenizer(model_args)["tokenizer"]
    get_template_and_fix_tokenizer(tokenizer, data_args.template)
    model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

View File

@@ -61,6 +61,9 @@ def create_modelcard_and_push(
    if data_args.dataset is not None:
        kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]

+    if model_args.use_unsloth:
+        kwargs["tags"] = kwargs["tags"] + ["unsloth"]
+
    if not training_args.do_train:
        pass
    elif training_args.push_to_hub:
@@ -88,7 +91,7 @@ def create_ref_model(
        )
        ref_model_args = ModelArguments(**ref_model_args_dict)
        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        tokenizer = load_tokenizer(ref_model_args)
+        tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
        ref_model = load_model(
            tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
        )
@@ -97,7 +100,7 @@ def create_ref_model(
        if finetuning_args.finetuning_type == "lora":
            ref_model = None
        else:
-            tokenizer = load_tokenizer(model_args)
+            tokenizer = load_tokenizer(model_args)["tokenizer"]
            ref_model = load_model(
                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
@@ -144,7 +147,7 @@ def create_reward_model(
        )
        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        tokenizer = load_tokenizer(reward_model_args)
+        tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
        reward_model = load_model(
            tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
        )

View File

@@ -2,6 +2,8 @@ import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple

+from numpy.typing import NDArray
+
from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
@@ -31,7 +33,10 @@ class WebChatModel(ChatModel):
        if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
            model_name_or_path = os.environ.get("DEMO_MODEL")
            template = os.environ.get("DEMO_TEMPLATE")
-            super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
+            infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
+            super().__init__(
+                dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
+            )

    @property
    def loaded(self) -> bool:
@@ -72,8 +77,9 @@ class WebChatModel(ChatModel):
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
-            flash_attn=(get("top.booster") == "flash_attn"),
+            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
+            visual_inputs=get("top.visual_inputs"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            infer_backend=get("infer.infer_backend"),
        )
@@ -109,6 +115,7 @@ class WebChatModel(ChatModel):
        messages: Sequence[Dict[str, str]],
        system: str,
        tools: str,
+        image: Optional[NDArray],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
@@ -116,7 +123,7 @@ class WebChatModel(ChatModel):
        chatbot[-1][1] = ""
        response = ""
        for new_text in self.stream_chat(
-            messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            if tools:

View File

@@ -9,6 +9,7 @@ from ..extras.constants import (
    DATA_CONFIG,
    DEFAULT_MODULE,
    DEFAULT_TEMPLATE,
+    MLLM_LIST,
    PEFT_METHODS,
    STAGES_USE_PAIR_DATA,
    SUPPORTED_MODELS,
@@ -105,6 +106,10 @@ def get_template(model_name: str) -> str:
    return "default"

+def get_visual(model_name: str) -> bool:
+    return get_prefix(model_name) in MLLM_LIST
+
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
    if finetuning_type not in PEFT_METHODS:
        return gr.Dropdown(value=[], choices=[], interactive=False)

View File

@@ -17,15 +17,21 @@ if TYPE_CHECKING:
def create_chat_box(
    engine: "Engine", visible: bool = False
-) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
+) -> Tuple["Component", "Component", Dict[str, "Component"]]:
    with gr.Column(visible=visible) as chat_box:
        chatbot = gr.Chatbot(show_copy_button=True)
        messages = gr.State([])
        with gr.Row():
            with gr.Column(scale=4):
-                role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
-                system = gr.Textbox(show_label=False)
-                tools = gr.Textbox(show_label=False, lines=2)
+                with gr.Row():
+                    with gr.Column():
+                        role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
+                        system = gr.Textbox(show_label=False)
+                        tools = gr.Textbox(show_label=False, lines=3)
+
+                    with gr.Column() as image_box:
+                        image = gr.Image(sources=["upload"], type="numpy")
+
                query = gr.Textbox(show_label=False, lines=8)
                submit_btn = gr.Button(variant="primary")
@@ -43,19 +49,21 @@ def create_chat_box(
        [chatbot, messages, query],
    ).then(
        engine.chatter.stream,
-        [chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
+        [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
        [chatbot, messages],
    )
    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

    return (
-        chat_box,
        chatbot,
        messages,
        dict(
+            chat_box=chat_box,
            role=role,
            system=system,
            tools=tools,
+            image_box=image_box,
+            image=image,
            query=query,
            submit_btn=submit_btn,
            max_new_tokens=max_new_tokens,

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Dict, Generator, List

+from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train import export_model
from ..common import get_save_dir
@@ -26,9 +27,11 @@ def save_model(
    adapter_path: List[str],
    finetuning_type: str,
    template: str,
-    max_shard_size: int,
+    visual_inputs: bool,
+    export_size: int,
    export_quantization_bit: int,
    export_quantization_dataset: str,
+    export_device: str,
    export_legacy_format: bool,
    export_dir: str,
    export_hub_model_id: str,
@@ -44,6 +47,8 @@ def save_model(
        error = ALERTS["err_no_dataset"][lang]
    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
        error = ALERTS["err_no_adapter"][lang]
+    elif export_quantization_bit in GPTQ_BITS and adapter_path:
+        error = ALERTS["err_gptq_lora"][lang]

    if error:
        gr.Warning(error)
@@ -62,24 +67,28 @@ def save_model(
        adapter_name_or_path=adapter_name_or_path,
        finetuning_type=finetuning_type,
        template=template,
+        visual_inputs=visual_inputs,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id or None,
-        export_size=max_shard_size,
+        export_size=export_size,
        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
        export_quantization_dataset=export_quantization_dataset,
+        export_device=export_device,
        export_legacy_format=export_legacy_format,
    )

    yield ALERTS["info_exporting"][lang]
    export_model(args)
+    torch_gc()
    yield ALERTS["info_exported"][lang]

def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
    with gr.Row():
-        max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
+        export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
        export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
+        export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
        export_legacy_format = gr.Checkbox()

    with gr.Row():
@@ -98,9 +107,11 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
            engine.manager.get_elem_by_id("top.adapter_path"),
            engine.manager.get_elem_by_id("top.finetuning_type"),
            engine.manager.get_elem_by_id("top.template"),
-            max_shard_size,
+            engine.manager.get_elem_by_id("top.visual_inputs"),
+            export_size,
            export_quantization_bit,
            export_quantization_dataset,
+            export_device,
            export_legacy_format,
            export_dir,
            export_hub_model_id,
@@ -109,9 +120,10 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
    )

    return dict(
-        max_shard_size=max_shard_size,
+        export_size=export_size,
        export_quantization_bit=export_quantization_bit,
        export_quantization_dataset=export_quantization_dataset,
+        export_device=export_device,
        export_legacy_format=export_legacy_format,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id,

View File

@@ -28,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
    input_elems.update({infer_backend})
    elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))

-    chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
-    elem_dict.update(dict(chat_box=chat_box, **chat_elems))
+    chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
+    elem_dict.update(chat_elems)

    load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
-        lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
+        lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
    )

    unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
        lambda: ([], []), outputs=[chatbot, messages]
-    ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
+    ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
+
+    engine.manager.get_elem_by_id("top.visual_inputs").change(
+        lambda enabled: gr.Column(visible=enabled),
+        [engine.manager.get_elem_by_id("top.visual_inputs")],
+        [chat_elems["image_box"]],
+    )

    return elem_dict

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
-from ..common import get_model_path, get_template, list_adapters, save_config
+from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..utils import can_quantize
@@ -30,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
    with gr.Accordion(open=False) as advanced_tab:
        with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
-            template = gr.Dropdown(choices=list(templates.keys()), value="default")
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-            booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
+            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
+            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
+            visual_inputs = gr.Checkbox(scale=1)

    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
        get_model_path, [model_name], [model_path], queue=False
-    ).then(get_template, [model_name], [template], queue=False)  # do not save config since the below line will save
+    ).then(get_template, [model_name], [template], queue=False).then(
+        get_visual, [model_name], [visual_inputs], queue=False
+    )  # do not save config since the below line will save

    model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
@@ -59,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
        template=template,
        rope_scaling=rope_scaling,
        booster=booster,
+        visual_inputs=visual_inputs,
    )

View File

@@ -138,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
    with gr.Row():
        lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
        lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
-        lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+        lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
        loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
        create_new_adapter = gr.Checkbox()

View File

@@ -43,6 +43,7 @@ class Engine:
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())} init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())} init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())} init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False}
if user_config.get("last_model", None): if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]} init_dict["top.model_name"] = {"value": user_config["last_model"]}

View File

@@ -58,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
        lang = gr.Dropdown(choices=["en", "zh"])
        engine.manager.add_elems("top", dict(lang=lang))

-        chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
-        engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
+        _, _, chat_elems = create_chat_box(engine, visible=True)
+        engine.manager.add_elems("infer", chat_elems)

        demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
        lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)

View File

@@ -129,6 +129,17 @@ LOCALES = {
             "label": "加速方式",
         },
     },
+    "visual_inputs": {
+        "en": {
+            "label": "Visual inputs",
+        },
+        "ru": {
+            "label": "визуальные входы",
+        },
+        "zh": {
+            "label": "图像输入",
+        },
+    },
     "training_stage": {
         "en": {
             "label": "Stage",
@@ -1073,6 +1084,17 @@ LOCALES = {
             "placeholder": "工具列表(非必填)",
         },
     },
+    "image": {
+        "en": {
+            "label": "Image (optional)",
+        },
+        "ru": {
+            "label": "Изображение (по желанию)",
+        },
+        "zh": {
+            "label": "图像(非必填)",
+        },
+    },
     "query": {
         "en": {
             "placeholder": "Input...",
@@ -1150,7 +1172,7 @@ LOCALES = {
             "value": "清空历史",
         },
     },
-    "max_shard_size": {
+    "export_size": {
         "en": {
             "label": "Max shard size (GB)",
             "info": "The maximum size for a model file.",
@@ -1192,6 +1214,20 @@ LOCALES = {
             "info": "量化过程中使用的校准数据集。",
         },
     },
+    "export_device": {
+        "en": {
+            "label": "Export device",
+            "info": "Which device should be used to export model.",
+        },
+        "ru": {
+            "label": "Экспорт устройство",
+            "info": "Какое устройство следует использовать для экспорта модели.",
+        },
+        "zh": {
+            "label": "导出设备",
+            "info": "导出模型使用的设备类型。",
+        },
+    },
     "export_legacy_format": {
         "en": {
             "label": "Export legacy format",
@@ -1287,7 +1323,12 @@ ALERTS = {
     "err_no_export_dir": {
         "en": "Please provide export dir.",
         "ru": "Пожалуйста, укажите каталог для экспорта.",
        "zh": "请填写导出目录",
+    },
+    "err_gptq_lora": {
+        "en": "Please merge adapters before quantizing the model.",
+        "ru": "Пожалуйста, объедините адаптеры перед квантованием модели.",
+        "zh": "量化模型前请先合并适配器。",
     },
     "err_failed": {
         "en": "Failed.",

View File

@@ -60,4 +60,5 @@ class Manager:
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
             self._id_to_elem["top.booster"],
+            self._id_to_elem["top.visual_inputs"],
         }

View File

@@ -67,7 +67,7 @@ class Runner:
         if not model_path:
             return ALERTS["err_no_path"][lang]
-        if len(dataset) == 0:
+        if not dataset:
             return ALERTS["err_no_dataset"][lang]
         if not from_preview and self.demo_mode:
@@ -122,8 +122,9 @@ class Runner:
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
-            flash_attn=(get("top.booster") == "flashattn"),
+            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
+            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -222,8 +223,9 @@ class Runner:
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
-            flash_attn=(get("top.booster") == "flashattn"),
+            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
+            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("eval.dataset_dir"),
             dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),