70 Commits

Author SHA1 Message Date
hiyouga
c8fe3f544b release v0.6.3
Former-commit-id: 947572af8de201669598f54735f35b50bb719d71
2024-04-21 23:13:23 +08:00
hiyouga
0f1ad7140f fix #3366
Former-commit-id: dc20237455c36de44f8922539d7dfadd8bedb12f
2024-04-21 21:34:25 +08:00
hiyouga
233e167f68 fix optimizers
Former-commit-id: f811eee2fa12a89a55a9c5d3a05a1521b4347727
2024-04-21 20:40:54 +08:00
hiyouga
1d341dcd83 fix #3365
Former-commit-id: 415ce41e8fa887e980e5bd575c8e95bd4076b90b
2024-04-21 19:20:18 +08:00
hiyouga
d16561e7a4 fix bug in galore optimizer
Former-commit-id: c05ac23261a5a8ba893c2918a43dc7777307407b
2024-04-21 18:53:22 +08:00
hiyouga
f8e219dc81 fix mod stuff
Former-commit-id: cf3988226e6398c67bb2955578e436fc505aa5c5
2024-04-21 18:11:10 +08:00
hoshi-hiyouga
3365cc8cf0 Merge pull request #3338 from astramind-ai/main
Adding Mixture of Depth

Former-commit-id: 4da2ece53353b63e672ff529d6beba41ff710c14
2024-04-21 18:05:52 +08:00
hoshi-hiyouga
3a5e68b7d9 fix #3348
Former-commit-id: aa5e921c00f60074eceb2f9d4d8837cc713edba6
2024-04-20 10:34:09 +08:00
hiyouga
0cb596fee1 add dpo mix dataset
Former-commit-id: 6def3f8bfa51b2d9d73af112352ce07db972e4c9
2024-04-20 01:31:38 +08:00
hiyouga
b3b5b530d1 fix #3352
Former-commit-id: f315f8e8ec916b82bac94a159e55839ff155c6b5
2024-04-19 22:40:01 +08:00
hiyouga
9225c15c88 fix llama3 template
Former-commit-id: 20e95250168fbe081c779b2e1ff23f5df3ce02f7
2024-04-19 15:46:51 +08:00
Marco
abd9fed445 fix small typo
Former-commit-id: 5638a03cd0cf8119ff366b3b3e303b5a2351b065
2024-04-18 20:33:29 +02:00
Marco
44cda2eece Added Mixture of Depths
Former-commit-id: 75dd98b9abc847e22cb263c17ebcd2ca5dd98345
2024-04-18 20:31:24 +02:00
hoshi-hiyouga
8397808d1d support llama3
Former-commit-id: c1eabb751a5fd73b710714451b146732e0ed4558
2024-04-19 01:13:50 +08:00
hiyouga
9e1bd6420d fix #3324
Former-commit-id: 5e710c4ac331f3400534d33b2646c4108c898d98
2024-04-18 15:34:45 +08:00
hiyouga
619264c854 tiny fix
Former-commit-id: 86399ca8c06273c42c2b184664ae25d3405b3bf6
2024-04-18 00:22:17 +08:00
hiyouga
1ebac62e3d update readme
Former-commit-id: a49112a74339ba77bfec53f7870e821fe148db2c
2024-04-17 23:40:49 +08:00
hiyouga
ce9bdb3509 add mixtral 8x22B models
Former-commit-id: eccbeecff0909e1fa124b5439ffbbfbc5607e1d6
2024-04-17 23:35:59 +08:00
hiyouga
0c8d6369ac add CodeQwen models
Former-commit-id: 9f6094241391f8f717818c8ba94e11d1791b4a5c
2024-04-17 23:27:22 +08:00
hiyouga
bee796f6b5 fix #3316
Former-commit-id: 7395e9e90a209228ff563ab54319955608850fc3
2024-04-17 22:54:34 +08:00
hiyouga
9f6349a333 fix #3317
Former-commit-id: 7dce1763be4374cf616d96db95ae964ff510a9d6
2024-04-17 22:17:19 +08:00
hiyouga
171a029c5e lint
Former-commit-id: 917d65ce65024d17a5030bc57083a427cfae16d7
2024-04-16 18:21:09 +08:00
hoshi-hiyouga
eaefaa0fe0 Merge pull request #3291 from codemayq/main
support for previewing custom dataset in directory format

Former-commit-id: 40d89152282101a7c08f53e72c2ad7124a0595f3
2024-04-16 18:12:09 +08:00
hiyouga
d301f0a64b Update parser.py
Former-commit-id: 92c2133896c20054db86dd53508c982e39bd5ca0
2024-04-16 18:09:31 +08:00
hiyouga
0a1578e4e3 update readme and gradio version
Former-commit-id: 4029b60ddcbd15b5354503c51178f0f5e7e9aedf
2024-04-16 18:09:16 +08:00
hiyouga
a4167fd925 support badam for all stages
Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
2024-04-16 17:44:48 +08:00
hoshi-hiyouga
42084e08ae Merge pull request #3287 from Ledzy/badam
[Feature] Add BAdam algorithm

Former-commit-id: 10a5e1e65b34b03e5ca2a41bf6ded09a3fb25f0c
2024-04-16 17:32:16 +08:00
hoshi-hiyouga
9d23f5dc89 Update utils.py
Former-commit-id: 01147536b2bb507e87e033fa696e9eb39fe96bbe
2024-04-16 17:30:12 +08:00
hoshi-hiyouga
5978427ae0 Update trainer.py
Former-commit-id: c6163be1444c00dd000f288e2f834968bd932981
2024-04-16 17:29:52 +08:00
hoshi-hiyouga
c7c216069c Update utils.py
Former-commit-id: 7edf4dbed88b8034282f14fd6e0cb6f7f9e5f805
2024-04-16 17:29:30 +08:00
hoshi-hiyouga
cde9d1b917 Update patcher.py
Former-commit-id: 494e6a1e05b38f5ff61d83327303614f53c92e64
2024-04-16 17:29:19 +08:00
hoshi-hiyouga
96213f04b0 Update adapter.py
Former-commit-id: 8f7b75b26f020d8ae85baab7b082475c3bfeb512
2024-04-16 17:28:12 +08:00
hoshi-hiyouga
7ecea08b9b Update parser.py
Former-commit-id: 898239883afc79f03abd0dc276eef901662a9591
2024-04-16 17:27:25 +08:00
hoshi-hiyouga
191971865d Update parser.py
Former-commit-id: 2f3da8169d18b026760cc0ac7dd6141bdd08c932
2024-04-16 17:27:02 +08:00
hoshi-hiyouga
ff4f587dd9 Update finetuning_args.py
Former-commit-id: 3a23d900aea74078f0bc8cf73fac860a4ce3df67
2024-04-16 17:26:30 +08:00
hoshi-hiyouga
de728d0371 Update sft.sh
Former-commit-id: 2b4b1562e91bbb02e345e71b7721da9333c0791b
2024-04-16 17:25:40 +08:00
hoshi-hiyouga
d08e09642d Update requirements.txt
Former-commit-id: 1e45537ca0bb4d49b4147df01122e365b3d617e4
2024-04-16 17:10:17 +08:00
hoshi-hiyouga
351493b183 Update setup.py
Former-commit-id: 5df30ea166aff29d48ff83a22ac6ef1611ce3e35
2024-04-16 17:10:02 +08:00
Jonery
86ab47e121 remove badam from core requirements
Former-commit-id: fa5898944a3867ac5108dd0d579ca0677c87d3d6
2024-04-16 12:25:50 +08:00
Jonery
6dd6b3e396 resolve gradient checkpointing issue.
Former-commit-id: 6df9135d063bb6102f0cbcdf0d702076f5febbae
2024-04-16 12:05:27 +08:00
codingma
5f1418a68b add check
Former-commit-id: 008f6498977c243c80e87242f05c9cf9573541ac
2024-04-16 10:56:39 +08:00
codingma
7b97a79efc support for previewing custom dataset in directory format
Former-commit-id: 501cff38c819f06f15194907ce7e052d5f28025a
2024-04-16 10:43:14 +08:00
hiyouga
ce4f653121 add empty template
Former-commit-id: a325ffa8a668bec354d2636683806acef105e196
2024-04-16 03:10:02 +08:00
hiyouga
b053c6454e update readme
Former-commit-id: 8f233745c3aa7a6ef57f275bec80ee731ff76de3
2024-04-16 02:36:54 +08:00
hiyouga
ebf0f4a77c update readme
Former-commit-id: f9a246572c1ec0e4b36bff237c6523ce629b7000
2024-04-16 02:35:36 +08:00
hiyouga
efa808069a support unsloth 2024.4
Former-commit-id: 14a83f8bc4fe44783252378fce59198194a96bb8
2024-04-16 00:25:03 +08:00
hiyouga
b5c5283dd6 add codegemma
Former-commit-id: 9324176525c2eda22962b0ca1895009b6237e6e3
2024-04-16 00:11:15 +08:00
hiyouga
b638c65519 support cohere commandR #3184
Former-commit-id: e077c36872740f6b2ac255aee9da6c4c70f28977
2024-04-15 23:26:42 +08:00
Jonery
d4d471450f Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
2024-04-15 23:15:27 +08:00
hoshi-hiyouga
3144bdec2c Merge pull request #3254 from marko1616/feature/Add-support-for-CohereForAI/c4ai-command-r-plus
Add template&support for c4ai-command-r/plus (tested)

Former-commit-id: 41d39ec4889abad050820bf153133ac3a11228a3
2024-04-15 22:59:35 +08:00
hoshi-hiyouga
c6d6c4c209 Update template.py
Former-commit-id: 00b8be7dafa65e13b344724a8d3855919ee4f631
2024-04-15 22:58:01 +08:00
hoshi-hiyouga
f5f1589662 Update constants.py
Former-commit-id: 39199f712aa7b7a1c66080d9c84651fd2eb0b425
2024-04-15 22:56:55 +08:00
hiyouga
276f2cb24e update examples
Former-commit-id: 369294b31c8a03a1cafcee83eb31a817007d3c49
2024-04-15 22:14:34 +08:00
marko1616
952b785bb3 change default_system according to official template
Former-commit-id: 7ad9029c5e77a87a7c324b8f90b4f80a31a5c78b
2024-04-15 20:45:46 +08:00
marko1616
72dd676208 Revert "Add support for function call(Not strictly following origin)"
This reverts commit dfaa31e991 [formerly 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6].


Former-commit-id: fac9cc6e01dd8f3bc449b656804476e1871326f0
2024-04-15 20:27:09 +08:00
marko1616
dfaa31e991 Add support for function call(Not strictly following origin)
Former-commit-id: 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6
2024-04-15 20:16:52 +08:00
hoshi-hiyouga
86556b1c74 Merge pull request #3261 from khazic/main
Added specimens for single-card full parameter prediction

Former-commit-id: 60df2a9519fbd8215c3afacc831b0cc89006457a
2024-04-15 16:30:57 +08:00
hoshi-hiyouga
0c80751e87 Merge pull request #3276 from liu-zichen/fix_mixtral
fix: turn on output_router_logits of mixtral
Former-commit-id: 07bbaf5c67d00a152e5304e81b15fd9189e7bb99
2024-04-15 15:38:16 +08:00
hiyouga
9338f878a3 fix #3273
Former-commit-id: 3b20c89b342a068356ffc29c3724b645775c65db
2024-04-15 15:32:58 +08:00
liuzc
fde3d91242 fix: mixtral output_router_logits
Former-commit-id: ab3171ea97ec968b972287287ef9ee2502c6d37c
2024-04-15 12:11:49 +08:00
khazic
19adfb88a9 Upgrade README.md
Former-commit-id: 697f768d7185789ee054c94f4f161a65b8a505bc
2024-04-13 20:50:49 +08:00
khazic
daaafa900a Added specimens for single-card full parameter prediction
Former-commit-id: d8d4fb9fa4b0e1950a453682e5e186f34f085dee
2024-04-13 20:45:19 +08:00
marko1616
0dcc9e0bca Typo fix
Former-commit-id: 607625497738b2c8be736be7b0bd5c6f4cbaad5e
2024-04-13 17:30:21 +08:00
marko1616
aeec78b35c Typo fix
Former-commit-id: 51b1e49e288e66c1b0c24ac070201c988fb2a389
2024-04-13 07:52:11 +08:00
marko1616
c991654cb4 Add c4ai-command-r-plus link
Former-commit-id: acaf953ca46eca8fb378067f4ada133654e4f088
2024-04-13 07:32:40 +08:00
marko1616
f328413646 Add template&support(Not tested)
Former-commit-id: 60bb60c4dc30a9641ddb57a44ef126f0768566c4
2024-04-13 04:31:33 +08:00
hiyouga
106a0104da fix #3247
Former-commit-id: bb67c66f80627805b585d157ba807c0ce378d3f2
2024-04-12 17:41:33 +08:00
hiyouga
5486ea09e3 fix model card
Former-commit-id: 920e7149bf2b559c9829aa4b11cfb6d00bbb2f9e
2024-04-12 17:11:59 +08:00
hiyouga
31bbbb6d13 fix #3238
Former-commit-id: 4d7e81ab4722d13bec6ca1af141f94bdc74d0883
2024-04-12 14:28:11 +08:00
hiyouga
1a77de82fa set dev version
Former-commit-id: f6cc76571d2c789675883a18e0db3d0c61f33808
2024-04-11 20:27:34 +08:00
54 changed files with 871 additions and 367 deletions

View File

@@ -5,7 +5,7 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
 [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
-[![Citation](https://img.shields.io/badge/citation-28-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -46,7 +46,7 @@ Choose your path:
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -68,14 +68,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
+[24/04/19] We supported **Meta Llama 3** model series.
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
+[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+<details><summary>Full Changelog</summary>
 [24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
 [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
 [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.
-<details><summary>Full Changelog</summary>
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
 [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
@@ -129,20 +137,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | Model | Model size | Default module | Template |
 | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
@@ -241,6 +251,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
 </details>
@@ -275,16 +286,15 @@ huggingface-cli login
 \* *estimated*
-| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
-| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
-| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
-| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
-| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
-| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
+| Method | Bits | 7B | 13B | 30B | 70B | 8x7B | 8x22B |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB | 1200GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB | 400GB |
+| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 120GB | 320GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB | 160GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB | 96GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB | 48GB |
 ## Getting Started
@@ -305,7 +315,7 @@ cd LLaMA-Factory
 pip install -e .[metrics]
 ```
-Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: deepspeed, metrics, unsloth, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
 <details><summary>For Windows users</summary>
@@ -328,6 +338,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
 ```bash
 export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
+export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
 python src/train_web.py # or python -m llmtuner.webui.interface
 ```
@@ -413,8 +424,14 @@ If you have a project that should be incorporated, please contact via email or c
 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -427,7 +444,7 @@ If you have a project that should be incorporated, please contact via email or c
 This repository is licensed under the [Apache-2.0 License](LICENSE).
-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 ## Citation

View File

@@ -5,7 +5,7 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
 [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
-[![Citation](https://img.shields.io/badge/citation-28-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-34-green)](#使用了-llama-factory-的项目)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -46,7 +46,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 - **集成方法**:增量预训练、指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
 - **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
-- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
+- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -68,14 +68,22 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 更新日志
+[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`
+[24/04/19] 我们支持了 **Meta Llama 3** 系列模型。
+[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`
+[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
+<details><summary>展开日志</summary>
 [24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`
 [24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
 [24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/extras/fsdp_qlora`
-<details><summary>展开日志</summary>
 [24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`
 [24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`
@@ -129,20 +137,22 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | 模型名 | 模型大小 | 默认模块 | Template |
 | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
@@ -241,6 +251,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
 </details>
@@ -275,16 +286,15 @@ huggingface-cli login
 \* *估算值*
-| 训练方法 | 精度 | 7B | 13B | 30B | 70B | 8x7B |
-| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
-| 全参数 | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
-| 全参数 | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
-| 部分参数 | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
-| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
+| 方法 | 精度 | 7B | 13B | 30B | 70B | 8x7B | 8x22B |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB | 1200GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB | 400GB |
+| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 120GB | 320GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB | 160GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB | 96GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB | 48GB |
 ## 如何使用
@@ -305,7 +315,7 @@ cd LLaMA-Factory
 pip install -e .[metrics]
 ```
-可选的额外依赖项:deepspeed、metrics、unsloth、galore、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
+可选的额外依赖项:deepspeed、metrics、unsloth、galore、badam、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
 <details><summary>Windows 用户指南</summary>
@@ -328,6 +338,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 ```bash
 export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
+export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
 python src/train_web.py # 或 python -m llmtuner.webui.interface
 ```
@@ -388,7 +399,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 ## 使用了 LLaMA Factory 的项目
 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
 <details><summary>点击显示</summary>
@@ -413,8 +424,14 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
@@ -427,7 +444,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
-使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 ## 引用

View File

@@ -1 +1 @@
-34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
+a97cf9475291591843976554878568e046d8a46d

View File

@@ -1,5 +1,6 @@
-import os
 import json
+import os
 import datasets
@@ -22,31 +23,19 @@ _URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0
 class BelleMultiturn(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.0")
     def _info(self):
-        features = datasets.Features({
-            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
-        })
+        features = datasets.Features(
+            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
         )
     def _split_generators(self, dl_manager: datasets.DownloadManager):
         file_path = dl_manager.download(_URL)
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepath": file_path
-                }
-            )
-        ]
+        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
     def _generate_examples(self, filepath: str):
         with open(filepath, "r", encoding="utf-8") as f:
@@ -58,7 +47,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                 assist_idx = prompt.rfind("Assistant:")
                 human_idx = prompt.rfind("Human:")
-                query = prompt[human_idx+6:assist_idx].strip()
+                query = prompt[human_idx + 6 : assist_idx].strip()
                 prompt = prompt[:human_idx].strip()
                 conversations.insert(0, {"from": "gpt", "value": response})
                 conversations.insert(0, {"from": "human", "value": query})
@@ -67,8 +56,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                     assist_idx = prompt.rfind("Assistant:")
                     human_idx = prompt.rfind("Human:")
                     if human_idx != -1:
-                        old_query = prompt[human_idx+6:assist_idx].strip()
-                        old_resp = prompt[assist_idx+10:].strip()
+                        old_query = prompt[human_idx + 6 : assist_idx].strip()
+                        old_resp = prompt[assist_idx + 10 :].strip()
                         conversations.insert(0, {"from": "gpt", "value": old_resp})
                         conversations.insert(0, {"from": "human", "value": old_query})
                     else:
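Note on the hunk above: the slice offsets are simply the lengths of the literal turn markers, `len("Human:") == 6` and `len("Assistant:") == 10`. A minimal, self-contained sketch of the same backward-parsing idea; the sample strings are made up and the loop condition is simplified, so this is illustrative only, not the repository's code verbatim:

```python
# Illustrative only: rebuild a turn list from a flattened BELLE-style prompt
# by walking backwards with rfind() and inserting each turn at the front.
prompt = "Human:你好\nAssistant:你好,请问有什么可以帮您?\nHuman:请用一句话介绍你自己\nAssistant:"
response = "我是一个用于演示的对话助手。"

conversations = []
assist_idx = prompt.rfind("Assistant:")          # trailing, empty assistant marker
human_idx = prompt.rfind("Human:")
query = prompt[human_idx + 6 : assist_idx].strip()  # 6 == len("Human:")
prompt = prompt[:human_idx].strip()
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})

while prompt.rfind("Human:") != -1:
    assist_idx = prompt.rfind("Assistant:")
    human_idx = prompt.rfind("Human:")
    old_query = prompt[human_idx + 6 : assist_idx].strip()
    old_resp = prompt[assist_idx + 10 :].strip()     # 10 == len("Assistant:")
    conversations.insert(0, {"from": "gpt", "value": old_resp})
    conversations.insert(0, {"from": "human", "value": old_query})
    prompt = prompt[:human_idx].strip()

print(conversations)  # oldest turn first, ending with the final response
```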

View File

@@ -1,7 +1,8 @@
 import json
-import datasets
 from typing import Any, Dict, Generator, List, Tuple
+import datasets
 _DESCRIPTION = "An example of dataset."
 _CITATION = ""
@@ -11,34 +12,24 @@ _URL = "examples.json"
 class ExampleDataset(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.0")
     def _info(self) -> datasets.DatasetInfo:
-        features = datasets.Features({
-            "instruction": datasets.Value("string"),
-            "input": datasets.Value("string"),
-            "output": datasets.Value("string"),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
-        })
+        features = datasets.Features(
+            {
+                "instruction": datasets.Value("string"),
+                "input": datasets.Value("string"),
+                "output": datasets.Value("string"),
+                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
+            }
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
         )
     def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
         file_path = dl_manager.download(_URL)
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepath": file_path
-                }
-            )
-        ]
+        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
     def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
         example_dataset = json.load(open(filepath, "r", encoding="utf-8"))

View File

@@ -1,8 +1,10 @@
-import os
 import json
-import datasets
+import os
 from typing import List
+import datasets
 _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
 _DESCRIPTION = "Human preference data about helpfulness and harmlessness."
 _CITATION = ""
@@ -14,50 +16,37 @@ _URLS = {
         _URL + "harmless-base/train.jsonl.gz",
         _URL + "helpful-base/train.jsonl.gz",
         _URL + "helpful-online/train.jsonl.gz",
-        _URL + "helpful-rejection-sampled/train.jsonl.gz"
+        _URL + "helpful-rejection-sampled/train.jsonl.gz",
     ],
     "test": [
         _URL + "harmless-base/test.jsonl.gz",
         _URL + "helpful-base/test.jsonl.gz",
         _URL + "helpful-online/test.jsonl.gz",
-        _URL + "helpful-rejection-sampled/test.jsonl.gz"
-    ]
+        _URL + "helpful-rejection-sampled/test.jsonl.gz",
+    ],
 }
 class HhRlhfEn(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.0")
     def _info(self) -> datasets.DatasetInfo:
-        features = datasets.Features({
-            "instruction": datasets.Value("string"),
-            "output": datasets.Sequence(datasets.Value("string")),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
-        })
+        features = datasets.Features(
+            {
+                "instruction": datasets.Value("string"),
+                "output": datasets.Sequence(datasets.Value("string")),
+                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
+            }
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )
     def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_path = dl_manager.download_and_extract(_URLS)
        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepaths": file_path["train"]
-                }
-            ),
-            datasets.SplitGenerator(
-                name=datasets.Split.TEST,
-                gen_kwargs={
-                    "filepaths": file_path["test"]
-                }
-            )
+            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
+            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
        ]
     def _generate_examples(self, filepaths: List[str]):
@@ -70,12 +59,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
                     rejected = data["rejected"]
                     assist_idx = rejected.rfind("\n\nAssistant: ")
-                    r_reject = rejected[assist_idx+13:].strip()
+                    r_reject = rejected[assist_idx + 13 :].strip()
                     assist_idx = chosen.rfind("\n\nAssistant: ")
-                    r_accept = chosen[assist_idx+13:].strip()
+                    r_accept = chosen[assist_idx + 13 :].strip()
                     human_idx = chosen.rfind("\n\nHuman: ")
-                    query = chosen[human_idx+9:assist_idx].strip()
+                    query = chosen[human_idx + 9 : assist_idx].strip()
                     prompt = chosen[:human_idx]
                     history = []
@@ -83,16 +72,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
                         assist_idx = prompt.rfind("\n\nAssistant: ")
                         human_idx = prompt.rfind("\n\nHuman: ")
                         if human_idx != -1:
-                            old_query = prompt[human_idx+9:assist_idx].strip()
-                            old_resp = prompt[assist_idx+13:].strip()
+                            old_query = prompt[human_idx + 9 : assist_idx].strip()
+                            old_resp = prompt[assist_idx + 13 :].strip()
                             history.insert(0, (old_query, old_resp))
                         else:
                             break
                         prompt = prompt[:human_idx]
-                    yield key, {
-                        "instruction": query,
-                        "output": [r_accept, r_reject],
-                        "history": history
-                    }
+                    yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
                     key += 1
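Note on the hunk above: the `+ 9` and `+ 13` offsets correspond to the lengths of the HH-RLHF turn markers, which include the two leading newlines. A minimal, self-contained sketch of that arithmetic; the `chosen`/`rejected` strings below are invented examples in the same format, not data from the dataset:

```python
# Illustrative only: why the hh_rlhf_en loader slices with "+ 9" and "+ 13".
assert len("\n\nHuman: ") == 9
assert len("\n\nAssistant: ") == 13

# Hypothetical record in the Anthropic HH-RLHF format.
chosen = "\n\nHuman: What is LoRA?\n\nAssistant: A low-rank adaptation method."
rejected = "\n\nHuman: What is LoRA?\n\nAssistant: No idea."

assist_idx = rejected.rfind("\n\nAssistant: ")
r_reject = rejected[assist_idx + 13 :].strip()   # text after the last assistant marker
assist_idx = chosen.rfind("\n\nAssistant: ")
r_accept = chosen[assist_idx + 13 :].strip()

human_idx = chosen.rfind("\n\nHuman: ")
query = chosen[human_idx + 9 : assist_idx].strip()  # last human turn before the answer

# Each yielded example pairs the preferred and rejected answers for the same query.
example = {"instruction": query, "output": [r_accept, r_reject], "history": []}
print(example)
```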

View File

@@ -1,8 +1,10 @@
import os
import json import json
import datasets import os
from typing import List from typing import List
import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co") _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data." _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
@@ -24,31 +26,19 @@ _BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jso
class UltraChat(datasets.GeneratorBasedBuilder): class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0") VERSION = datasets.Version("0.0.0")
def _info(self): def _info(self):
features = datasets.Features({ features = datasets.Features(
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}] {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
}) )
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
) )
def _split_generators(self, dl_manager: datasets.DownloadManager): def _split_generators(self, dl_manager: datasets.DownloadManager):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [ return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_paths
}
)
]
def _generate_examples(self, filepaths: List[str]): def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths: for filepath in filepaths:
@@ -56,7 +46,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
for row in f: for row in f:
try: try:
data = json.loads(row) data = json.loads(row)
except: except Exception:
continue continue
key: int = data["id"] key: int = data["id"]
content: List[str] = data["data"] content: List[str] = data["data"]
@@ -64,8 +54,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
content.pop(-1)
if len(content) < 2:
continue
- conversations = [{
- "from": "human" if i % 2 == 0 else "gpt",
- "value": content[i]
- } for i in range(len(content))]
+ conversations = [
+ {"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
+ ]
yield key, {"conversations": conversations}

View File

@@ -3,41 +3,46 @@ We provide diverse examples about fine-tuning LLMs.
```
examples/
├── lora_single_gpu/
│   ├── pretrain.sh: Do continuous pre-training using LoRA
│   ├── sft.sh: Do supervised fine-tuning using LoRA
│   ├── reward.sh: Do reward modeling using LoRA
│   ├── ppo.sh: Do PPO training using LoRA
│   ├── dpo.sh: Do DPO training using LoRA
│   ├── orpo.sh: Do ORPO training using LoRA
│   ├── prepare.sh: Save tokenized dataset
│   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
├── qlora_single_gpu/
│   ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
│   ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
│   ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
│   └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
├── lora_multi_gpu/
│   ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
│   └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
├── full_multi_gpu/
│   ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
│   ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
│   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after full tuning
├── merge_lora/
│   ├── merge.sh: Merge LoRA weights into the pre-trained models
│   └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
├── inference/
│   ├── cli_demo.sh: Launch a command line interface with LoRA adapters
│   ├── api_demo.sh: Launch an OpenAI-style API with LoRA adapters
│   ├── web_demo.sh: Launch a web interface with LoRA adapters
│   └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
└── extras/
    ├── galore/
    │   └── sft.sh: Fine-tune model with GaLore
    ├── badam/
    │   └── sft.sh: Fine-tune model with BAdam
    ├── loraplus/
    │   └── sft.sh: Fine-tune model using LoRA+
    ├── mod/
    │   └── sft.sh: Fine-tune model using Mixture-of-Depths
    ├── llama_pro/
    │   ├── expand.sh: Expand layers in the model
    │   └── sft.sh: Fine-tune the expanded model
    └── fsdp_qlora/
        └── sft.sh: Fine-tune quantized model with FSDP+QLoRA
```
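Each entry above corresponds to a shell script under `examples/`. The same hyperparameters can also be supplied programmatically; a minimal, hypothetical sketch (assuming `llmtuner.run_exp` accepts a plain dict of arguments, as the parser signatures later in this changeset suggest):

```
from llmtuner import run_exp

# Illustrative values only; they mirror the flags used by the example scripts.
run_exp(
    dict(
        stage="sft",
        do_train=True,
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        dataset="alpaca_gpt4_en",
        template="default",
        finetuning_type="lora",
        lora_target="q_proj,v_proj",
        output_dir="saves/LLaMA2-7B/lora/sft",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        num_train_epochs=3.0,
        fp16=True,
    )
)
```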

View File

@@ -1,43 +1,48 @@
We provide a variety of example scripts for fine-tuning large language models.
```
examples/
├── lora_single_gpu/
│   ├── pretrain.sh: Continuous pre-training with LoRA
│   ├── sft.sh: Supervised instruction fine-tuning with LoRA
│   ├── reward.sh: Reward model training with LoRA
│   ├── ppo.sh: PPO training with LoRA
│   ├── dpo.sh: DPO training with LoRA
│   ├── orpo.sh: ORPO training with LoRA
│   ├── prepare.sh: Save the preprocessed dataset
│   └── predict.sh: Batch prediction with LoRA and computation of BLEU and ROUGE scores
├── qlora_single_gpu/
│   ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models with QLoRA
│   ├── gptq.sh: Fine-tune 4/8-bit GPTQ models with QLoRA
│   ├── awq.sh: Fine-tune 4-bit AWQ models with QLoRA
│   └── aqlm.sh: Fine-tune 2-bit AQLM models with QLoRA
├── lora_multi_gpu/
│   ├── single_node.sh: Single-node LoRA training with Accelerate
│   └── multi_node.sh: Multi-node LoRA training with Accelerate
├── full_multi_gpu/
│   ├── single_node.sh: Single-node full-parameter training with DeepSpeed
│   ├── multi_node.sh: Multi-node full-parameter training with DeepSpeed
│   └── predict.sh: Batch prediction with the fully fine-tuned model and computation of BLEU and ROUGE scores
├── merge_lora/
│   ├── merge.sh: Merge LoRA weights into the pre-trained model
│   └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
├── inference/
│   ├── cli_demo.sh: Launch a command-line inference interface for the LoRA model
│   ├── api_demo.sh: Launch an OpenAI-style API for the LoRA model
│   ├── web_demo.sh: Launch a web inference interface for the LoRA model
│   └── evaluate.sh: Evaluate the LoRA model on the MMLU/CMMLU/C-Eval datasets
└── extras/
    ├── galore/
    │   └── sft.sh: Train a model with GaLore
    ├── badam/
    │   └── sft.sh: Train a model with BAdam
    ├── loraplus/
    │   └── sft.sh: Train a model with LoRA+
    ├── mod/
    │   └── sft.sh: Train a model with Mixture-of-Depths
    ├── llama_pro/
    │   ├── expand.sh: Expand layers in the model
    │   └── sft.sh: Train the expanded model
    └── fsdp_qlora/
        └── sft.sh: Fine-tune a quantized model with FSDP+QLoRA
```

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--mixture_of_depths convert \
--output_dir ../../../saves/LLaMA2-7B/mod/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--optim paged_adamw_8bit \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -0,0 +1,35 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--use_badam \
--badam_switch_mode descending \
--badam_switch_block_every 50 \
--badam_verbose 2 \
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -12,6 +12,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--galore_layerwise \
--galore_target mlp,self_attn \
--galore_rank 128 \
+ --galore_scale 2.0 \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \

View File

@@ -9,6 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
+ --loraplus_lr_ratio 16.0 \
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
--overwrite_cache \
--overwrite_output_dir \
@@ -29,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
- --fp16 \
- --loraplus_lr_ratio 16.0
+ --fp16

View File

@@ -0,0 +1,18 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type full \
--output_dir ../../saves/LLaMA2-7B/full/predict \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_eval_batch_size 1 \
--max_samples 20 \
--predict_with_generate

View File

@@ -3,7 +3,7 @@
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
- --template vanilla \
+ --template fewshot \
--finetuning_type lora \
--task mmlu \
--split test \

View File

@@ -1,7 +1,7 @@
#!/bin/bash
# DO NOT use quantized model or quantization_bit when merging lora weights
- CUDA_VISIBLE_DEVICES= python ../../src/export_model.py \
+ CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \

View File

@@ -4,7 +4,7 @@ datasets>=2.14.3
accelerate>=0.27.2
peft>=0.10.0
trl>=0.8.1
- gradio>=4.0.0,<=4.21.0
+ gradio>=4.0.0
scipy
einops
sentencepiece

View File

@@ -24,6 +24,7 @@ extra_require = {
"metrics": ["nltk", "jieba", "rouge-chinese"], "metrics": ["nltk", "jieba", "rouge-chinese"],
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"], "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam"],
"vllm": ["vllm>=0.3.3"], "vllm": ["vllm>=0.3.3"],
"bitsandbytes": ["bitsandbytes>=0.39.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],

View File

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
- __version__ = "0.6.2"
+ __version__ = "0.6.3"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -5,14 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
+ from vllm import AsyncLLMEngine
from ..data import Template
- from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
- if is_vllm_available():
- from vllm import AsyncLLMEngine
@dataclass
class Response:

View File

@@ -1,8 +1,6 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
- from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
@@ -25,7 +23,6 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path, model=model_args.model_name_or_path,

View File

@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
- template = templates["vanilla"]  # placeholder
+ template = templates["empty"]  # placeholder
else:
template = templates.get(name, None)
if template is None:
@@ -385,7 +385,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=( default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request." "Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
), ),
) )
@@ -526,6 +527,21 @@ _register_template(
)
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -566,6 +582,13 @@ _register_template(
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -574,6 +597,13 @@ _register_template(
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -635,9 +665,28 @@ _register_template(
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
),
default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"],
replace_eos=True,
)
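For reference, a rough sketch (illustrative only) of the string these slots render for a single user turn; the bos_token value is an assumption about the Llama-3 tokenizer:

```
bos_token = "<|begin_of_text|>"  # assumed Llama-3 bos token
system = "You are a helpful assistant."
user = "Hello!"

prompt = (
    bos_token
    + f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    + f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    + "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(prompt)
```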
_register_template(
name="mistral",
- format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+ format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
@@ -699,13 +748,6 @@ _register_template(
)
- _register_template(
- name="vanilla",
- format_separator=EmptyFormatter(slots=["\n"]),
- efficient_eos=True,
- )
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),

View File

@@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"] PEFT_METHODS = ["lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
@@ -242,6 +244,28 @@ register_model_group(
)
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
@@ -363,6 +387,23 @@ register_model_group(
)
register_model_group(
models={
"CodeGemma-2B": {
DownloadSource.DEFAULT: "google/codegemma-2b",
},
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"InternLM-7B": {
@@ -474,6 +515,29 @@ register_model_group(
)
register_model_group(
models={
"LLaMA3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"LLaMA3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"LLaMA3-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"LLaMA3-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
},
template="llama3",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
@@ -499,14 +563,20 @@ register_model_group(
register_model_group(
models={
- "Mixtral-8x7B": {
+ "Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
- "Mixtral-8x7B-Chat": {
+ "Mixtral-8x7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
+ "Mixtral-8x22B-v0.1": {
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
+ },
+ "Mixtral-8x22B-v0.1-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
+ },
},
template="mistral",
)
@@ -688,6 +758,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
}, },
"Qwen1.5-Code-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"Qwen1.5-0.5B-Chat": { "Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -720,6 +794,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
}, },
"Qwen1.5-Code-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": { "Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
@@ -776,6 +854,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
}, },
"Qwen1.5-Code-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
}, },
template="qwen", template="qwen",
) )
@@ -979,18 +1061,3 @@ register_model_group(
},
template="zephyr",
)
- register_model_group(
- models={
- "Atom-7B": {
- DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
- DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
- },
- "Atom-7B-Chat": {
- DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
- DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
- },
- },
- template="atom",
- )

View File

@@ -66,7 +66,6 @@ def check_dependencies() -> None:
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2") require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0") require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1") require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
@@ -84,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if param.__class__.__name__ == "Params4bit": if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else: else:
num_bytes = 1 num_bytes = 1

View File

@@ -25,6 +25,10 @@ def is_galore_available():
return _is_package_available("galore_torch") return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available(): def is_jieba_available():
return _is_package_available("jieba") return _is_package_available("jieba")
@@ -49,10 +53,6 @@ def is_starlette_available():
return _is_package_available("sse_starlette") return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")

View File

@@ -172,7 +172,7 @@ class GaloreArguments:
use_galore: bool = field(
default=False,
- metadata={"help": "Whether or not to use gradient low-Rank projection."},
+ metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
@@ -204,7 +204,54 @@ class GaloreArguments:
@dataclass
- class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+ class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_update_ratio: float = field(
default=0.0,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": """The mode of the mask for BAdam optimizer. \
`adjacent` means that the trainable parameters are adjacent to each other, \
`scatter` means that trainable parameters are randomly chosen from the weight."""
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": """The verbosity level of BAdam optimizer. \
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
},
)
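A hypothetical sketch of how these switches might be parsed together; the field names mirror the dataclass above and the values echo examples/extras/badam/sft.sh:

```
from transformers import HfArgumentParser

parser = HfArgumentParser(BAdamArgument)
(badam_args,) = parser.parse_dict(
    {
        "use_badam": True,
        "badam_mode": "layer",
        "badam_switch_mode": "descending",
        "badam_switch_block_every": 50,
        "badam_verbose": 2,
    }
)
print(badam_args)
```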
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r""" r"""
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
@@ -256,11 +303,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full": if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.") raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
if self.use_galore and self.finetuning_type == "lora": if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.") raise ValueError("Cannot use LoRA with GaLore together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`.""" r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"

View File

@@ -55,7 +55,7 @@ class ModelArguments:
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
- metadata={"help": "Device map used for loading the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
+ metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
@@ -63,12 +63,16 @@ class ModelArguments:
)
flash_attn: bool = field(
default=False,
- metadata={"help": "Enable FlashAttention-2 for faster training."},
+ metadata={"help": "Enable FlashAttention for faster training."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
@@ -129,6 +133,10 @@ class ModelArguments:
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: str = field(
default="cpu",
metadata={"help": "The device used in model export."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},

View File

@@ -8,10 +8,10 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
+ from transformers.utils.versions import require_version
from ..extras.logging import get_logger
- from ..extras.misc import check_dependencies
+ from ..extras.misc import check_dependencies, get_current_device
- from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -74,6 +74,32 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
@@ -131,8 +157,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
- if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
- raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
+ if training_args.do_train and model_args.quantization_device_map == "auto":
+ raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
@@ -151,13 +177,21 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
- if finetuning_args.use_galore and training_args.deepspeed is not None:
- raise ValueError("GaLore is incompatible with DeepSpeed.")
+ if (
+ finetuning_args.use_badam
+ and finetuning_args.badam_mode == "layer"
+ and training_args.parallel_mode.value == "distributed"
+ ):
+ raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+ if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
+ raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if ( if (
training_args.do_train training_args.do_train
@@ -235,6 +269,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
elif training_args.fp16:
model_args.compute_dtype = torch.float16
+ model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
@@ -276,8 +311,12 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
raise ValueError("vLLM engine does not support RoPE scaling.") raise ValueError("vLLM engine does not support RoPE scaling.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto" if model_args.export_dir is not None:
model_args.device_map = {"": torch.device(model_args.export_device)}
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args return model_args, data_args, finetuning_args, generating_args
@@ -294,6 +333,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto" model_args.device_map = "auto"

View File

@@ -32,9 +32,12 @@ def init_adapter(
logger.info("Adapter is not found at evaluation, load the base model.") logger.info("Adapter is not found at evaluation, load the base model.")
return model return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if finetuning_args.finetuning_type == "full" and is_trainable: if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16: if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float() model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable: if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -66,6 +69,8 @@ def init_adapter(
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
+ elif ".1." in name:  # MoD starts from layer 1
+ freeze_modules.add(name.split(".1.")[-1].split(".")[0])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
@@ -79,7 +84,7 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
- if not finetuning_args.pure_bf16:
+ if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -129,9 +134,12 @@ def init_adapter(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
- if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
- if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
- raise ValueError("DoRA is not compatible with PTQ-quantized models.")
+ if (
+ finetuning_args.use_dora
+ and getattr(model, "quantization_method", None) is not None
+ and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+ ):
+ raise ValueError("DoRA is not compatible with PTQ-quantized models.")
peft_kwargs = {
"r": finetuning_args.lora_rank,
@@ -139,24 +147,28 @@ def init_adapter(
"lora_alpha": finetuning_args.lora_alpha, "lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout, "lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora, "use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
} }
if model_args.use_unsloth: if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
- modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
- if not finetuning_args.pure_bf16:
+ if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
+ from ..extras.constants import MOD_SUPPORTED_MODELS
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
@@ -36,13 +37,22 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
- tokenizer = AutoTokenizer.from_pretrained(
- model_args.model_name_or_path,
- use_fast=model_args.use_fast_tokenizer,
- split_special_tokens=model_args.split_special_tokens,
- padding_side="right",
- **init_kwargs,
- )
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ split_special_tokens=model_args.split_special_tokens,
+ padding_side="right",
+ **init_kwargs,
+ )
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
patch_tokenizer(tokenizer)
return tokenizer
@@ -73,6 +83,8 @@ def load_model(
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
"device_map": {"": get_current_device()}, "device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None), "rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
} }
try: try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
@@ -85,7 +97,24 @@ def load_model(
logger.warning("Unsloth does not support loading adapters.") logger.warning("Unsloth does not support loading adapters.")
if model is None: if model is None:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs) init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
from MoD import AutoMoDModelForCausalLM
model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
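A condensed, hypothetical sketch of the convert path above; `apply_mod_to_hf` comes from the MoD package ("mixture-of-depth") required in parser.py, and the checkpoint name and dtype are only examples:

```
import torch
from transformers import AutoModelForCausalLM
from MoD import apply_mod_to_hf  # package "mixture-of-depth", as pinned in parser.py

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = apply_mod_to_hf(model)    # rewrite supported decoder blocks for MoD routing
model = model.to(torch.bfloat16)  # mirrors the compute_dtype cast above
```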
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)
@@ -93,7 +122,7 @@ def load_model(
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if add_valuehead:
- model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:

View File

@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
- from .utils import QuantizationMethod, add_z3_leaf_module
+ from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
if TYPE_CHECKING:
@@ -61,9 +61,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
return samples
- def _configure_attn_implementation(
- config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
- ) -> None:
+ def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
@@ -73,9 +71,9 @@ def _configure_attn_implementation(
if getattr(config, "model_type", None) == "internlm2": # special case for custom models if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", "flash_attention_2") setattr(config, "attn_implementation", "flash_attention_2")
else: else:
init_kwargs["attn_implementation"] = "flash_attention_2" setattr(config, "_attn_implementation", "flash_attention_2")
else: else:
init_kwargs["attn_implementation"] = "eager" setattr(config, "_attn_implementation", "eager")
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -133,12 +131,15 @@ def _configure_quantization(
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
- init_kwargs["device_map"] = {"": get_current_device()}
+ if model_args.quantization_device_map != "auto":
+ init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+ quantization_config.pop("disable_exllama", None)  # remove deprecated args
quantization_config["use_exllama"] = False  # disable exllama
if quant_method == QuantizationMethod.AWQ:
@@ -266,8 +267,8 @@ def _prepare_model_for_training(
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
+ model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
- model.enable_input_require_grads()
setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
@@ -293,7 +294,7 @@ def patch_config(
if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
- _configure_attn_implementation(config, model_args, init_kwargs)
+ _configure_attn_implementation(config, model_args)
_configure_rope(config, model_args, is_trainable)
_configure_longlora(config, model_args, is_trainable)
_configure_quantization(config, tokenizer, model_args, init_kwargs)
@@ -316,15 +317,15 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn: if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable: if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
setattr(config, "output_router_logits", True) setattr(config, "output_router_logits", True)
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]: if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs: if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()} init_kwargs["device_map"] = model_args.device_map
if init_kwargs["device_map"] == "auto": if init_kwargs["device_map"] == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder init_kwargs["offload_folder"] = model_args.offload_folder

View File

@@ -1,5 +1,7 @@
import inspect
from enum import Enum, unique
- from typing import TYPE_CHECKING, Dict, List
+ from functools import partial
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
@@ -100,6 +102,42 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
return module_names
def gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
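For context, a small sketch of how this replacement hook gets attached; it mirrors the MethodType binding used in the model patcher earlier in this changeset, and the checkpoint name is only an example:

```
from types import MethodType

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
```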
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import nullcontext
+ from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
@@ -63,6 +64,11 @@ class CustomDPOTrainer(DPOTrainer):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+ if finetuning_args.use_badam:
+ from badam import clip_grad_norm_for_sparse_tensor
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
+ from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
@@ -44,6 +45,10 @@ class CustomORPOTrainer(DPOTrainer):
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
+ if finetuning_args.use_badam:
+ from badam import clip_grad_norm_for_sparse_tensor
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:

View File

@@ -1,6 +1,7 @@
import math
import os
import sys
+ from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -124,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
+ if finetuning_args.use_badam:
+ from badam import clip_grad_norm_for_sparse_tensor
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.

View File

@@ -1,3 +1,4 @@
+from types import MethodType
 from typing import TYPE_CHECKING, Optional
 from transformers import Trainer
@@ -23,6 +24,10 @@ class CustomTrainer(Trainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

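The trainers themselves only check `finetuning_args.use_badam`; the BAdam-specific fields they rely on are read by `_create_badam_optimizer` further down in this change. A hypothetical sketch of those fields (names follow this diff, default values here are illustrative only, not the project's):

```python
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class BAdamArgumentsSketch:
    use_badam: bool = False                          # switch the trainers onto the BAdam path
    badam_mode: Literal["layer", "ratio"] = "layer"
    badam_switch_block_every: int = 50               # layer mode: steps between block switches
    badam_switch_mode: str = "ascending"             # layer mode: order in which blocks are visited
    badam_start_block: Optional[int] = None          # layer mode: index of the first trainable block
    badam_update_ratio: float = 0.05                 # ratio mode: fraction of parameters updated per step
    badam_mask_mode: str = "adjacent"                # ratio mode: layout of the update mask
    badam_verbose: int = 0                           # verbosity forwarded to badam
```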
View File

@@ -1,5 +1,6 @@
 import json
 import os
+from types import MethodType
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
@@ -28,6 +29,10 @@ class PairwiseTrainer(Trainer):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 import numpy as np
-from transformers.utils.versions import require_version
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
@@ -33,10 +32,6 @@ class ComputeMetrics:
r""" r"""
Uses the model predictions to compute metrics. Uses the model predictions to compute metrics.
""" """
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
preds, labels = eval_preds preds, labels = eval_preds
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

View File

@@ -1,5 +1,6 @@
 import json
 import os
+from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
@@ -28,6 +29,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

View File

@@ -65,8 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
output_dtype = getattr(model.config, "torch_dtype", torch.float16) output_dtype = getattr(model.config, "torch_dtype", torch.float16)
setattr(model.config, "torch_dtype", output_dtype) setattr(model.config, "torch_dtype", output_dtype)
for param in model.parameters(): model = model.to(output_dtype)
param.data = param.data.to(output_dtype)
model.save_pretrained( model.save_pretrained(
save_directory=model_args.export_dir, save_directory=model_args.export_dir,

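The hunk above swaps the per-parameter cast loop for a single `model.to(output_dtype)`, which also casts registered buffers. A toy comparison (standalone module, not the project's model):

```python
import torch

module = torch.nn.Linear(4, 4)

# old approach: cast each parameter tensor in place (buffers are left untouched)
for param in module.parameters():
    param.data = param.data.to(torch.float16)

# new approach: one call moves parameters and buffers to the target dtype
module = module.to(torch.float16)
```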
View File

@@ -5,7 +5,6 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
@@ -57,9 +56,11 @@ def create_modelcard_and_push(
     kwargs = {
         "tasks": "text-generation",
         "finetuned_from": model_args.model_name_or_path,
-        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
         "tags": ["llama-factory", finetuning_args.finetuning_type],
     }
+    if data_args.dataset is not None:
+        kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]
     if not training_args.do_train:
         pass
     elif training_args.push_to_hub:
@@ -166,8 +167,6 @@ def _create_galore_optimizer(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer": ) -> "torch.optim.Optimizer":
require_version("galore_torch", "To fix: pip install galore_torch")
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model) galore_targets = find_all_linear_modules(model)
else: else:
@@ -217,7 +216,7 @@ def _create_galore_optimizer(
         optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
         for param in nodecay_params:
-            param_groups = [dict(params=[param])]
+            param_groups = [dict(params=[param], weight_decay=0.0)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
         for param in decay_params:
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
@@ -237,7 +236,7 @@ def _create_galore_optimizer(
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
-            dict(params=nodecay_params),
+            dict(params=nodecay_params, weight_decay=0.0),
             dict(params=decay_params, weight_decay=training_args.weight_decay),
             dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
         ]
@@ -252,11 +251,9 @@ def _create_loraplus_optimizer(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer": ) -> "torch.optim.Optimizer":
if finetuning_args.finetuning_type != "lora": default_lr = training_args.learning_rate
raise ValueError("You should use LoRA tuning to activate LoRA+.")
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
decay_args = {"weight_decay": training_args.weight_decay} embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model) decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = { param_dict: Dict[str, List["torch.nn.Parameter"]] = {
@@ -279,16 +276,76 @@ def _create_loraplus_optimizer(
     optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
     param_groups = [
-        dict(params=param_dict["lora_a"], **decay_args),
-        dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
-        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
-        dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
+        dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
+        dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
+        dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
     logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
     return optimizer
+def _create_badam_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    decay_params, nodecay_params = [], []
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if name in decay_param_names:
+                decay_params.append(param)
+            else:
+                nodecay_params.append(param)
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=nodecay_params, weight_decay=0.0),
+        dict(params=decay_params, weight_decay=training_args.weight_decay),
+    ]
+    if finetuning_args.badam_mode == "layer":
+        from badam import BlockOptimizer
+        base_optimizer = optim_class(param_groups, **optim_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.badam_switch_block_every,
+            start_block=finetuning_args.badam_start_block,
+            switch_mode=finetuning_args.badam_switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+            f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+            f"default start block is {finetuning_args.badam_start_block}"
+        )
+    elif finetuning_args.badam_mode == "ratio":
+        from badam import BlockOptimizerRatio
+        assert finetuning_args.badam_update_ratio > 1e-6
+        optimizer = BlockOptimizerRatio(
+            param_groups=param_groups,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            include_embedding=False,
+            **optim_kwargs,
+        )
+        logger.info(
+            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
+    return optimizer
 def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
@@ -300,6 +357,9 @@ def create_custom_optimzer(
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
@@ -314,13 +374,12 @@ def create_custom_scheduler(
             scheduler_dict[param] = get_scheduler(
                 training_args.lr_scheduler_type,
                 optimizer=optimizer_dict[param],
-                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
-                num_training_steps=num_training_steps * 2,
+                num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
             )
         def scheduler_hook(param: "torch.nn.Parameter"):
-            if param.grad is not None:
-                scheduler_dict[param].step()
+            scheduler_dict[param].step()
         for param in optimizer_dict.keys():
             param.register_post_accumulate_grad_hook(scheduler_hook)

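The scheduler hunk above drops the `* 2` scaling and the `param.grad is not None` guard on the per-layer (layer-wise GaLore) path, where every trainable tensor gets its own optimizer and scheduler driven by `register_post_accumulate_grad_hook`. A self-contained sketch of that pattern, with plain SGD standing in for the GaLore optimizers and the optimizer and scheduler steps combined into one hook for brevity:

```python
import torch
from transformers.optimization import get_scheduler

model = torch.nn.Linear(8, 8)
num_training_steps = 100

optimizer_dict, scheduler_dict = {}, {}
for param in model.parameters():
    optimizer_dict[param] = torch.optim.SGD([param], lr=1e-3)
    scheduler_dict[param] = get_scheduler(
        "linear",
        optimizer=optimizer_dict[param],
        num_warmup_steps=0,
        num_training_steps=num_training_steps,  # no longer doubled
    )


def post_accumulate_hook(param: torch.nn.Parameter) -> None:
    # fires for each parameter right after its gradient is accumulated
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()
    scheduler_dict[param].step()  # stepped unconditionally, matching the fix above


for param in model.parameters():
    param.register_post_accumulate_grad_hook(post_accumulate_hook)

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # updates happen inside the hooks; no explicit optimizer.step() call
```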
View File

@@ -1,13 +1,11 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
-import gradio as gr
-from gradio.components import Component  # cannot use TYPE_CHECKING here
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.misc import torch_gc
+from ..extras.packages import is_gradio_available
 from .common import get_save_dir
 from .locales import ALERTS
@@ -17,6 +15,10 @@ if TYPE_CHECKING:
     from .manager import Manager
+if is_gradio_available():
+    import gradio as gr
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
@@ -35,7 +37,7 @@ class WebChatModel(ChatModel):
     def loaded(self) -> bool:
         return self.engine is not None
-    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+    def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang = get("top.lang")
         error = ""
@@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang] yield ALERTS["info_loaded"][lang]
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: def unload_model(self, data) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_id("top.lang")] lang = data[self.manager.get_elem_by_id("top.lang")]
if self.demo_mode: if self.demo_mode:

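Note: this file and the webui modules below all replace the unconditional `import gradio as gr` with an import guarded by `is_gradio_available()`, and move the `Component` import under `TYPE_CHECKING`. A minimal sketch of the pattern, approximating the project's helper with `importlib` (the real one lives in `..extras.packages`):

```python
import importlib.util
from typing import TYPE_CHECKING


def is_gradio_available() -> bool:  # approximation of the project's helper
    return importlib.util.find_spec("gradio") is not None


if is_gradio_available():
    import gradio as gr  # imported only when the optional dependency is installed

if TYPE_CHECKING:
    from gradio.components import Component  # type-checking only, no runtime cost


def make_refresh_button() -> "gr.Button":  # string annotation stays safe without gradio
    return gr.Button(interactive=False)
```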
View File

@@ -3,7 +3,6 @@ import os
 from collections import defaultdict
 from typing import Any, Dict, Optional
-import gradio as gr
 from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
 from ..extras.constants import (
@@ -17,6 +16,11 @@ from ..extras.constants import (
     DownloadSource,
 )
 from ..extras.misc import use_modelscope
+from ..extras.packages import is_gradio_available
+if is_gradio_available():
+    import gradio as gr
 ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

View File

@@ -1,11 +1,14 @@
 from typing import TYPE_CHECKING, Dict, Tuple
-import gradio as gr
 from ...data import Role
+from ...extras.packages import is_gradio_available
 from ..utils import check_json_schema
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
     from gradio.components import Component

View File

@@ -1,10 +1,13 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
-import gradio as gr
 from ...extras.constants import DATA_CONFIG
+from ...extras.packages import is_gradio_available
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
@@ -29,28 +32,38 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
     except Exception:
         return gr.Button(interactive=False)
-    if (
-        len(dataset) > 0
-        and "file_name" in dataset_info[dataset[0]]
-        and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
-    ):
+    if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
+        return gr.Button(interactive=False)
+    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+    if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
         return gr.Button(interactive=True)
     else:
         return gr.Button(interactive=False)
+def _load_data_file(file_path: str) -> List[Any]:
+    with open(file_path, "r", encoding="utf-8") as f:
+        if file_path.endswith(".json"):
+            return json.load(f)
+        elif file_path.endswith(".jsonl"):
+            return [json.loads(line) for line in f]
+        else:
+            return list(f)
 def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
-    data_file: str = dataset_info[dataset[0]]["file_name"]
-    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
-        if data_file.endswith(".json"):
-            data = json.load(f)
-        elif data_file.endswith(".jsonl"):
-            data = [json.loads(line) for line in f]
-        else:
-            data = [line for line in f]  # noqa: C416
+    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+    if os.path.isfile(data_path):
+        data = _load_data_file(data_path)
+    else:
+        data = []
+        for file_name in os.listdir(data_path):
+            data.extend(_load_data_file(os.path.join(data_path, file_name)))
     return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)

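With the change above, a dataset's `file_name` in `dataset_info.json` may point to a directory of `.json`/`.jsonl` shards instead of a single file, and the preview merges every file it finds. A hedged example of the directory layout this enables (the dataset name and paths are illustrative):

```python
import json
import os

dataset_dir = "data"
dataset_info = {
    "my_sharded_dataset": {"file_name": "my_sharded_dataset"}  # a folder, not a single file
}

# e.g. data/my_sharded_dataset/part-0.jsonl, data/my_sharded_dataset/part-1.jsonl, ...
data_path = os.path.join(dataset_dir, dataset_info["my_sharded_dataset"]["file_name"])
if os.path.isdir(data_path):
    records = []
    for file_name in os.listdir(data_path):
        with open(os.path.join(data_path, file_name), "r", encoding="utf-8") as f:
            records.extend(json.loads(line) for line in f)  # assumes .jsonl shards
    print(f"previewing {len(records)} records")
```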
View File

@@ -1,11 +1,14 @@
 from typing import TYPE_CHECKING, Dict
-import gradio as gr
+from ...extras.packages import is_gradio_available
 from ..common import DEFAULT_DATA_DIR, list_dataset
 from .data import create_preview_box
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -18,7 +21,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
-        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
     input_elems.update({dataset_dir, dataset})

View File

@@ -1,12 +1,15 @@
 from typing import TYPE_CHECKING, Dict, Generator, List
-import gradio as gr
+from ...extras.packages import is_gradio_available
 from ...train import export_model
 from ..common import get_save_dir
 from ..locales import ALERTS
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
     from gradio.components import Component

View File

@@ -1,10 +1,13 @@
 from typing import TYPE_CHECKING, Dict
-import gradio as gr
+from ...extras.packages import is_gradio_available
 from .chatbot import create_chat_box
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
     from gradio.components import Component

View File

@@ -1,13 +1,16 @@
 from typing import TYPE_CHECKING, Dict
-import gradio as gr
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.packages import is_gradio_available
 from ..common import get_model_path, get_template, list_adapters, save_config
 from ..utils import can_quantize
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
     from gradio.components import Component

View File

@@ -1,13 +1,17 @@
 from typing import TYPE_CHECKING, Dict
-import gradio as gr
 from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
+from ...extras.packages import is_gradio_available
 from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
 from ..components.data import create_preview_box
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -23,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
         )
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
-        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
     input_elems.update({training_stage, dataset_dir, dataset})

View File

@@ -1,6 +1,4 @@
-from typing import Any, Dict, Generator
+from typing import TYPE_CHECKING, Any, Dict
-from gradio.components import Component  # cannot use TYPE_CHECKING here
 from .chatter import WebChatModel
 from .common import get_model_path, list_dataset, load_config
@@ -10,6 +8,10 @@ from .runner import Runner
 from .utils import get_time
+if TYPE_CHECKING:
+    from gradio.components import Component
 class Engine:
     def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
         self.demo_mode = demo_mode
@@ -29,7 +31,7 @@ class Engine:
         return output_dict
-    def resume(self) -> Generator[Dict[Component, Component], None, None]:
+    def resume(self):
         user_config = load_config() if not self.demo_mode else {}
         lang = user_config.get("lang", None) or "en"
@@ -55,7 +57,7 @@ class Engine:
         else:
             yield self._update_component({"eval.resume_btn": {"value": True}})
-    def change_lang(self, lang: str) -> Dict[Component, Component]:
+    def change_lang(self, lang: str):
         return {
             elem: elem.__class__(**LOCALES[elem_name][lang])
             for elem_name, elem in self.manager.get_elem_iter()

View File

@@ -1,5 +1,4 @@
-import gradio as gr
+from ..extras.packages import is_gradio_available
 from .common import save_config
 from .components import (
     create_chat_box,
@@ -13,6 +12,10 @@ from .css import CSS
 from .engine import Engine
+if is_gradio_available():
+    import gradio as gr
 def create_ui(demo_mode: bool = False) -> gr.Blocks:
     engine = Engine(demo_mode=demo_mode, pure_chat=False)

View File

@@ -4,9 +4,7 @@ import time
 from threading import Thread
 from typing import TYPE_CHECKING, Any, Dict, Generator
-import gradio as gr
 import transformers
-from gradio.components import Component  # cannot use TYPE_CHECKING here
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_cuda_available
@@ -14,13 +12,20 @@ from ..extras.callbacks import LogCallback
 from ..extras.constants import TRAINING_STAGES
 from ..extras.logging import LoggerHandler
 from ..extras.misc import get_device_count, torch_gc
+from ..extras.packages import is_gradio_available
 from ..train import run_exp
 from .common import get_module, get_save_dir, load_args, load_config, save_args
 from .locales import ALERTS
 from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
+if is_gradio_available():
+    import gradio as gr
 if TYPE_CHECKING:
+    from gradio.components import Component
     from .manager import Manager
@@ -239,7 +244,7 @@ class Runner:
         return args
-    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
+    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
         output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
         error = self._initialize(data, do_train, from_preview=True)
         if error:
@@ -249,7 +254,7 @@ class Runner:
         args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
         yield {output_box: gen_cmd(args)}
-    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
+    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
         output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
         error = self._initialize(data, do_train, from_preview=False)
         if error:
@@ -263,19 +268,19 @@ class Runner:
         self.thread.start()
         yield from self.monitor()
-    def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
+    def preview_train(self, data):
         yield from self._preview(data, do_train=True)
-    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
+    def preview_eval(self, data):
         yield from self._preview(data, do_train=False)
-    def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
+    def run_train(self, data):
         yield from self._launch(data, do_train=True)
-    def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
+    def run_eval(self, data):
         yield from self._launch(data, do_train=False)
-    def monitor(self) -> Generator[Dict[Component, Any], None, None]:
+    def monitor(self):
         get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
         self.aborted = False
         self.running = True
@@ -332,7 +337,7 @@ class Runner:
             yield return_dict
-    def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
+    def save_args(self, data):
         output_box = self.manager.get_elem_by_id("train.output_box")
         error = self._initialize(data, do_train=True, from_preview=True)
         if error:
@@ -351,7 +356,7 @@ class Runner:
         save_path = save_args(config_path, config_dict)
         return {output_box: ALERTS["info_config_saved"][lang] + save_path}
-    def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
+    def load_args(self, lang: str, config_path: str):
         output_box = self.manager.get_elem_by_id("train.output_box")
         config_dict = load_args(config_path)
         if config_dict is None:

View File

@@ -3,21 +3,24 @@ import os
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, Optional
-import gradio as gr
-from ..extras.packages import is_matplotlib_available
+from ..extras.packages import is_gradio_available, is_matplotlib_available
 from ..extras.ploting import smooth
 from .locales import ALERTS
-if TYPE_CHECKING:
-    from ..extras.callbacks import LogCallback
+if is_gradio_available():
+    import gradio as gr
 if is_matplotlib_available():
     import matplotlib.figure
     import matplotlib.pyplot as plt
+if TYPE_CHECKING:
+    from ..extras.callbacks import LogCallback
 def update_process_bar(callback: "LogCallback") -> "gr.Slider":
     if not callback.max_steps:
         return gr.Slider(visible=False)