85 Commits

Author SHA1 Message Date
hiyouga
d5f1b99ac4 Release v0.1.6
Former-commit-id: 43c8b3c3c8bfb2e32d17fb3e8b194938e37d54bd
2023-08-11 23:25:57 +08:00
hiyouga
2144bb0e27 Update README_zh.md
Former-commit-id: 4fc154bcf039ba3f9240213158df757881cf3579
2023-08-11 14:06:02 +08:00
hiyouga
bc665bacc7 add defaults
Former-commit-id: 4636d3bbe6b984ca93e3a80ae5239f3ddda461bd
2023-08-11 13:56:26 +08:00
hiyouga
52bfcf4883 fix stop word in baichuan template
Former-commit-id: cba5ac9cfc81f11b97831998ea15def5e0b487c2
2023-08-11 13:51:46 +08:00
hiyouga
06df3d6fb6 fix baichuan template
Former-commit-id: b1681fe35346381cda613297f1cbb710f0a6daa6
2023-08-11 13:45:47 +08:00
hiyouga
ca719a8697 support DPO training (2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
2023-08-11 03:02:53 +08:00
hoshi-hiyouga
72dfd74005 Merge pull request #451 from jovialchen/main
huggingface login for projects must login while running

Former-commit-id: 246ac241277908909b81cdf85fec1f24449dbae9
2023-08-10 17:25:38 +08:00
hiyouga
69302c4420 fix webui val size
Former-commit-id: 490c067d4e0828832e0ebdb704a9207dc974b15b
2023-08-10 15:20:44 +08:00
jiongxuc
42d7019b2e huggingface login for projects must login while running
Former-commit-id: 0a4a2a1d3e0ff1f57215512d294d782080bd383c
2023-08-10 14:57:12 +08:00
hiyouga
5f0d0d6b9b fix template
Former-commit-id: e3967eb1cdd8d19e8afee9ba52e7eb7d6cd86129
2023-08-09 23:14:27 +08:00
hiyouga
76cb63e4f6 fix template
Former-commit-id: 907e8cd86fbd4cdfa26dad21ceaf6e01d8fe37e4
2023-08-09 23:10:20 +08:00
hiyouga
467d571206 support val set in streaming mode
Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
2023-08-09 23:00:26 +08:00
hiyouga
972bfa700a fix tokenizer
Former-commit-id: 7849587cd4e149291d08edef9a528a1bad796c7e
2023-08-09 17:52:15 +08:00
hiyouga
990eeccf45 fix sft trainer
Former-commit-id: 08cc888b1569572d0cd20bcf3f07e20072a0311a
2023-08-09 16:35:03 +08:00
hiyouga
a3a7465f00 fix rm #420, fix template #426, fix #423
Former-commit-id: 70ea3caaa7a7695c77179cd1bb18707a80a373d7
2023-08-09 16:23:31 +08:00
hoshi-hiyouga
031a819257 fix llama2 template
Former-commit-id: 6c74f726d4e672f5a1a57df201c27c1f697384f0
2023-08-09 00:58:27 +08:00
hoshi-hiyouga
eb4b4e3c8c fix tokenizer
Former-commit-id: fa463ef279b596d5d53cc169831f51b42031fc05
2023-08-09 00:54:54 +08:00
hiyouga
d2e1fe9b1d update webui
Former-commit-id: 343a4cd82b07a40f96ba413d1d991419ff07a24a
2023-08-09 00:26:11 +08:00
hiyouga
6e27a9e39a fix tokenizer #417
Former-commit-id: 01aa678311bfd213a4b410a4e0ff09f48a0d40a1
2023-08-08 23:59:41 +08:00
hiyouga
805478c911 fix bug
Former-commit-id: 0dff1d951f1a9fe05a74d334bf477b55c7c64199
2023-08-08 21:28:28 +08:00
hiyouga
a281cdeb89 fix bug
Former-commit-id: c13ce66021b21e015871b84489eeafa127a424a4
2023-08-08 17:55:55 +08:00
hiyouga
cda698a67f fix chatml template #408
Former-commit-id: 21e0cc3f44c35ae689b00b274391492f413725ac
2023-08-08 17:44:39 +08:00
hiyouga
15acd17716 update args spec
Former-commit-id: a006068346edda6e2851b23d2005fdb218a7287d
2023-08-07 15:23:35 +08:00
hiyouga
34a2bddfcd update readme
Former-commit-id: 06bcbb901f69265632892a5fcbc956b8be1153da
2023-08-07 15:02:02 +08:00
hiyouga
370f817549 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Efficient-Tuning
Former-commit-id: 5c5657227db285048e3850631badb040eea9b6ca
2023-08-07 13:59:16 +08:00
hiyouga
041390c37e fix #376
Former-commit-id: a5b01257ba3323bcb2dd0217fb89a387e39ddbec
2023-08-07 13:58:59 +08:00
hoshi-hiyouga
d9fe4bf500 Merge pull request #382 from hiyouga/feature-updateReadme
add detailed model configs

Former-commit-id: 371c50cf3fd4e3f5e8fb390508c27cb5f18fa531
2023-08-07 13:43:38 +08:00
hiyouga
e0c7e944fc update trainer
Former-commit-id: 0d39b53a5164e34d22fe0a492eaa0d7ac63102fe
2023-08-07 13:34:35 +08:00
codemayq
0845fe67db add detailed model configs
Former-commit-id: 438c43f820e39738eaa1c296aadcf6d141c3289f
2023-08-07 09:30:23 +08:00
hiyouga
fe3b12d900 fix qwen eos token
Former-commit-id: 770830c67886f5872b39b9608949ec62d4616b27
2023-08-06 13:31:17 +08:00
hiyouga
a70d56864e fix qwen tokenizer #361
Former-commit-id: 78a2fa95c8ab669254a6c8fce8138c4395fb0a09
2023-08-05 17:06:05 +08:00
hiyouga
fdbb2c5378 fix template for tiktoken
Former-commit-id: 8328447f81eb5b90310df08cf2928c83ef6355fe
2023-08-05 13:42:42 +08:00
hiyouga
3c0aaf42af remove redundant code
Former-commit-id: dcec1717592107ba9d26eb2ac520309da19d1805
2023-08-05 00:27:27 +08:00
hiyouga
438e19160a fix template
Former-commit-id: b88200a88ea112e043dc44058606805c60e32844
2023-08-05 00:25:00 +08:00
hiyouga
f2b2ff6950 fix llama2 template
Former-commit-id: 08f37145e0bca5f1a8fd7bad01c64dc69b07361b
2023-08-05 00:07:54 +08:00
hoshi-hiyouga
86cef96305 Support safe ChatML template, fix qwen tok #351 #354
https://github.com/openai/openai-python/blob/main/chatml.md
Former-commit-id: 94bfc9d85f7cef3a5eb15085e0124a424373814f
2023-08-05 00:00:23 +08:00
hiyouga
5f50944baf fix bos and eos token
Former-commit-id: ab386f4c0fb5eaac24264a5bbef4c03deeb92158
2023-08-04 23:55:57 +08:00
hiyouga
0804fd2353 fix encode
Former-commit-id: ec382abd906d93cf78c7fbaec753ce6bcf8cfebd
2023-08-04 23:27:55 +08:00
hiyouga
86419eb457 support chatml safe encoding
Former-commit-id: ea52bb135bf9d07738091006ec7ada8df14cf15e
2023-08-04 23:14:28 +08:00
hiyouga
76f3ae7bf3 support interleave probs
Former-commit-id: 168d99816f9bdc746c587f7f09753ba7e0a4b19d
2023-08-04 21:27:35 +08:00
hiyouga
aaa85190eb fix webui export model
Former-commit-id: c34469c05e681239db23e2e666b5ac6a4e38aba9
2023-08-04 14:20:27 +08:00
hiyouga
e2a4e926b9 fix mtloader
Former-commit-id: ca48c2c02c3cfa9afb99971b50daeda9cf14e7cb
2023-08-03 19:29:02 +08:00
hiyouga
d6e922dc1c tiny fix
Former-commit-id: 81ef7017a4c96441951adeff0276cc5ab76a3544
2023-08-03 17:42:28 +08:00
hiyouga
27f4317ec6 fix qwen inference
Former-commit-id: 823f0de0ca0a92b6f48a90e5ffe57a48dc018f1d
2023-08-03 16:31:55 +08:00
hiyouga
e434348216 fix qwen inference
Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
2023-08-03 16:15:38 +08:00
hiyouga
2e19afedb8 support Qwen-7B, fix InternLM-7B inference
Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
2023-08-03 15:53:32 +08:00
hiyouga
da08fa7c63 update web demo
Former-commit-id: 5b6ad9adb665096bfb36dc90789a1d4a16345122
2023-08-03 13:28:28 +08:00
hiyouga
9c96b97dc7 fix webui
Former-commit-id: e87630ef77977b2879f1199b9a421acbbbb32a51
2023-08-03 12:43:12 +08:00
hiyouga
28a51b622b modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
2023-08-02 23:17:36 +08:00
hiyouga
8bd1da7144 fix PPO trainer
Former-commit-id: 21982a7d4dd9b7c3a1145b481f02b9990e32dc00
2023-08-02 19:10:23 +08:00
hiyouga
e4d0b8ee6e update ppo trainer
Former-commit-id: c27136a83e167465d3f825e40f10c7b9fcfbf97a
2023-08-02 18:46:41 +08:00
hiyouga
1dfb28b362 fix memory leak of PPO trainer
Former-commit-id: 38410894a5ebf0b043b55a6bd5cca3cd0a44b27d
2023-08-02 17:41:34 +08:00
hiyouga
ba618947e7 release v0.1.5
Former-commit-id: d619e76bc4098c29a7fdc05f5a71208bd1079c9f
2023-08-02 16:10:31 +08:00
hoshi-hiyouga
f81041b502 Merge pull request #307 from GitYCC/feature/fix-llama2-prompt-template
[feature] Fix template of Llama2 to match the official template

Former-commit-id: a750b1f1ed16e20233df4d2f1c20507122919f5a
2023-08-02 15:51:28 +08:00
YC Chen
f2533a2800 [fix] Remove useless code
Former-commit-id: 077e1556112913e4eeef47e581055183b39d5404
2023-08-02 14:35:35 +08:00
YC Chen
bb5b4a7f26 [feature] Fix template of Llama2 to match the official template
Former-commit-id: 1a98d45aefd95eea3768fb93e5a9da257ec61181
2023-08-02 14:10:15 +08:00
hiyouga
20bff87021 fix bug in preprocessing
Former-commit-id: 94952894576dfc4b42118162aec9aa35c3503c40
2023-08-02 01:10:28 +08:00
hiyouga
722b954800 update readme
Former-commit-id: 5154a04869be8c47e591351565b7842339fb99e4
2023-08-01 18:48:27 +08:00
hiyouga
19256086c7 fix #296
Former-commit-id: 69e9ed9b96a7cfb3d3b43ec5ddd01aa0bfd9b784
2023-08-01 18:43:53 +08:00
hiyouga
250fecfcd4 Fix #294
Former-commit-id: 09762d9849655f5e6c71b9472d55b42489dd944b
2023-08-01 18:13:03 +08:00
hiyouga
cb4d1d5ebb restore from git lfs
Former-commit-id: 0c734a37113b773ae7c0bc8b8d1af39b15bc0fb2
2023-08-01 16:33:25 +08:00
hiyouga
d7d557fb2e Update .gitattributes
Former-commit-id: 92e68f9f30c2fc91ae1b40865bc5c2d94899ba22
2023-08-01 16:28:54 +08:00
hiyouga
0b8e19b6a6 fix webui
Former-commit-id: cf4cd52d36894f53a6ec45d003f887771012e5b4
2023-08-01 12:11:37 +08:00
hiyouga
8e26eb374e fix RM save model
Former-commit-id: 8104cc2425431eb1cddccf3909855296116f922b
2023-08-01 11:56:17 +08:00
hiyouga
9bba01a033 use git lfs
Former-commit-id: 4886d0071751f68c5a2d926bd9fcee0c93337322
2023-08-01 10:14:08 +08:00
hiyouga
661890b8a1 release v0.1.4
Former-commit-id: 81f84aaf2e120e39edb28ef42893939fc9a184e2
2023-08-01 10:08:47 +08:00
hiyouga
772ad4ec6b fix inference
Former-commit-id: 55dc2bdd3eaa552c655e584fc3cbbf017c7bc3e7
2023-08-01 00:06:48 +08:00
hiyouga
6f65f8cb3b fix arg check
Former-commit-id: 2c5c73de9ebc88e2d04e80754781c94a571133a0
2023-07-31 23:48:57 +08:00
hiyouga
43e83548b9 update readme
Former-commit-id: d99cda254e5025ff3f968d256197ab031bfabef1
2023-07-31 23:42:32 +08:00
hiyouga
dd3f3e9749 support streaming data, fix #284 #274 #268
Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
2023-07-31 23:33:00 +08:00
hiyouga
124f61b404 Update data_args.py
Former-commit-id: 41ac5455af195747ba369c3a6dc7d412a366d54d
2023-07-28 17:42:41 +08:00
hiyouga
e8748cc6f3 update readme
Former-commit-id: 14d20cd1fdcfd1f2842362f70472b666e5d48c7d
2023-07-28 17:36:00 +08:00
hiyouga
fafec8b7a5 fix #268
Former-commit-id: 1eee0207fb370bb9e234e9bd3f9a0c47d7d01bc9
2023-07-28 17:02:26 +08:00
hiyouga
030daca686 update dataset
Former-commit-id: 4a044aabbd19c92a9ae93c1c30536f5086fd47f9
2023-07-26 17:05:12 +08:00
hiyouga
ac587438f8 fix #242
Former-commit-id: 80a346e29beb49e8935b786e2af1059fdc4954b2
2023-07-25 17:04:02 +08:00
hiyouga
c145bbef3c update dataset
Former-commit-id: 4fc2c3293d91d8464527ebd1ddabe572c8355616
2023-07-23 20:01:43 +08:00
hiyouga
745c46ee04 Update README_zh.md
Former-commit-id: 9d3c8803a34c06a2a5512fec3f841d7efcab3e3c
2023-07-22 14:31:16 +08:00
hiyouga
a707f5b502 update readme, fix web ui postprocess
Former-commit-id: ba51ab3379100108f7b52a3c2444ccdd99e8a6ef
2023-07-22 14:29:22 +08:00
hoshi-hiyouga
dc2e801077 Merge pull request #221 from mrhan1993/main
Add a Chinese README following GLM Efficient Tuning; add the server_port parameter to the web UI

Former-commit-id: 948f2abcb818211ee99d4c140e26044ca591369f
2023-07-22 13:04:25 +08:00
NULL
b56d5108b2 Merge branch 'hiyouga:main' into main
Former-commit-id: 8244c5b554ad0823e8ebea3d5583a6ecf9a66d2d
2023-07-21 17:00:26 +08:00
mrhan1993
8e6b7034fe Add a Chinese README following GLM Efficient Tuning; add server_port to the web UI
Former-commit-id: 29e3acd23eafd891667d7a860ec544a5b05d3c33
2023-07-21 16:57:58 +08:00
hiyouga
dad7ca6633 release v0.1.3
Former-commit-id: 62c68bcbf591516e8f90b47810bea6f710fd23f6
2023-07-21 16:48:34 +08:00
hiyouga
a1468139a5 fix save function
Former-commit-id: 1d6beb0c8490a7531ffdf7a2819410597b200d12
2023-07-21 14:09:07 +08:00
hiyouga
49c90044ce Update runner.py
Former-commit-id: d7309deae46cfcdeeee79f54736df9b7e93b79ce
2023-07-21 13:35:19 +08:00
hiyouga
0f7cdac207 update web UI, support rm predict #210
Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
2023-07-21 13:27:27 +08:00
67 changed files with 2475 additions and 899 deletions

README.md (143 changes)

@@ -8,17 +8,27 @@
👋 Join our [WeChat](assets/wechat.jpg).
\[ English | [中文](README_zh.md) \]
## Changelog
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
@@ -28,39 +38,46 @@
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
## Supported Models
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
- [InternLM](https://github.com/InternLM/InternLM) (7B)
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- |----------|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
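For example, pairing a chat model with its matching template from the table above (a minimal sketch; the remaining arguments follow the training examples below):
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
    --template llama2 \
    --lora_target q_proj,v_proj \
    --do_train \
    --dataset alpaca_gpt4_en \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint
```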
## Supported Training Approaches
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- Full-parameter tuning
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
- Full-parameter tuning
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | -------------- | ----------------- | ---- | ----- |
| Pre-Training | ✅ | ✅ | ✅ | ✅ |
| Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
| Reward Model Training | | | ✅ | ✅ |
| PPO Training | | | ✅ | ✅ |
| DPO Training | ✅ | | ✅ | ✅ |
## Provided Datasets
- For pre-training:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised fine-tuning:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
@@ -68,7 +85,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -77,12 +93,13 @@
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling:
- For reward modelling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -100,18 +117,13 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- sentencepiece and tiktoken
- jieba, rouge-chinese and nltk (used at evaluation)
- gradio and matplotlib (used in web_demo.py)
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
And **powerful GPUs**!
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
## Getting Started
### Data Preparation (optional)
@@ -130,13 +142,21 @@ cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### All-in-one Web UI
```bash
python src/train_web.py
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
### (Continually) Pre-Training
Currently the web UI only supports training on **a single GPU**.
### Pre-Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -144,6 +164,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
@@ -166,6 +187,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
@@ -180,6 +202,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
Remember to specify `--lora_target W_pack` if you are using Baichuan models.
### Reward Model Training
```bash
@@ -188,9 +212,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
@@ -201,7 +228,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### PPO Training (RLHF)
### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -209,7 +236,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
@@ -220,10 +249,33 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--resume_lora_training False \
--plot_loss
```
### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Distributed Training
```bash
@@ -267,6 +319,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
@@ -285,6 +338,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
@@ -293,13 +347,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
If you want to predict the samples with empty responses, please kindly fill the `response` column with **dummy tokens** to ensure the sample will not be discarded throughout the preprocessing phase.
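For example, a minimal prediction-only record might look like this (a hypothetical entry; any non-empty placeholder string works as the dummy response):
```json
{"instruction": "Summarize the paragraph below.", "input": "The quick brown fox ...", "output": "dummy"}
```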
### API Demo
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
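Once the server is running, any OpenAI-style client can talk to it; a minimal curl sketch, where the `/v1/chat/completions` route is an assumption based on the OpenAI chat API format the demo follows:
```bash
# hypothetical request; endpoint path assumed from OpenAI's chat API format
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
```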
@@ -311,6 +364,7 @@ Visit `http://localhost:8000/docs` for API documentation.
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@@ -320,6 +374,7 @@ python src/cli_demo.py \
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@@ -329,11 +384,18 @@ python src/web_demo.py \
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
## TODO
- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
- [ ] Implementing multi-query attention for faster inference.
- [ ] Supporting full-parameter RLHF training.
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).
@@ -344,8 +406,11 @@ Please follow the model licenses to use the corresponding model weights:
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## Citation

README_zh.md (new file, 431 lines)

@@ -0,0 +1,431 @@
# LLaMA Efficient Tuning
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Efficient-Tuning?style=social)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Efficient-Tuning)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
👋 Join our [WeChat group](assets/wechat.jpg).
\[ [English](README.md) | 中文 \]
## Changelog
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Remember to add `--template chatml` when using the Qwen-7B-Chat model.
[23/07/31] Now we support dataset streaming. Try the `--streaming` and `--max_steps 100` arguments to load your dataset in streaming mode.
[23/07/29] We released two 13B instruction-tuned models on Hugging Face. See our Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
[23/07/19] Now we support training the **LLaMA-2** models. Try the `--model_name_or_path meta-llama/Llama-2-7b-hf` argument. Remember to add `--template llama2` when using the LLaMA-2-chat model.
[23/07/18] We developed an all-in-one browser-based UI for training and testing. Try `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model. Try the `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments. Remember to add `--template baichuan` when using the Baichuan-13B-Chat model.
[23/07/09] We released [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for efficiently editing the factual knowledge of large language models. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/07/07] Now we support training the **InternLM-7B** model. Try the `--model_name_or_path internlm/internlm-7b` argument. Remember to add `--template intern` when using the InternLM-chat model.
[23/07/05] Now we support training the **Falcon-7B/40B** models. Try the `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments.
[23/06/29] We provided a **reproducible example** of fine-tuning an instruction-following model; see this [Hugging Face repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI API](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into any ChatGPT-based application.
[23/06/15] Now we support training the **Baichuan-7B** model. Try the `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments.
[23/06/03] Now we support 4-bit LoRA training (aka [QLoRA](https://github.com/artidoro/qlora)). Try the `--quantization_bit 4` argument for 4-bit quantized fine-tuning.
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models. Try the `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments.
## Supported Models
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- |----------|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
- The **default module** is one of the candidate values for the `--lora_target` argument. Use `python src/train_bash.py -h` to see all available options.
- For all "base" models, the `--template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But make sure to use the corresponding template for the "chat" models.
## Supported Training Approaches
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | -------------- | ----------------- | ---- | ----- |
| Pre-Training | ✅ | ✅ | ✅ | ✅ |
| Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
| Reward Model Training | | | ✅ | ✅ |
| PPO Training | | | ✅ | ✅ |
| DPO Training | ✅ | | ✅ | ✅ |
## Datasets
- For pre-training:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised fine-tuning:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
Please refer to [data/README.md](data/README_zh.md) for details on how to use these datasets.
Some datasets require confirmation before use, so we recommend logging into your Hugging Face account with the following commands.
```bash
pip install --upgrade huggingface_hub
huggingface-cli login
```
## Requirements
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- sentencepiece and tiktoken
- jieba, rouge-chinese and nltk (used for evaluation)
- gradio and matplotlib (used in the web UI)
- uvicorn, fastapi and sse-starlette (used in the API demo)
And **powerful GPUs**!
## Getting Started
### Data Preparation (optional)
Please refer to the contents of the `data/example_dataset` folder for the dataset file format. A custom dataset can be built either from a single `.json` file or from a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files.
Note: when using a custom dataset, please update the `data/dataset_info.json` file, whose format is described in `data/README.md`.
### Environment Setup (optional)
```bash
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### All-in-one Web UI
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
Currently the web UI only supports training on **a single GPU**.
### Pre-Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_your_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
### Supervised Fine-Tuning
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
Remember to specify `--lora_target W_pack` when using Baichuan models.
### Reward Model Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
```
### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Distributed Training with Multiple GPUs
```bash
accelerate config # configure the distributed environment first
accelerate launch src/train_bash.py # same arguments as above
```
<details><summary>Example Accelerate configuration for full-parameter fine-tuning with DeepSpeed ZeRO-2</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
### Metric Evaluation (BLEU and Chinese ROUGE scores)
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
We recommend using the `--per_device_eval_batch_size=1` and `--max_target_length 128` arguments when evaluating quantized models.
### Predict
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
### API Demo
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
Visit `http://localhost:8000/docs` for the API documentation.
### CLI Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Export Fine-tuned Model
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
## TODO
- [ ] Implementing flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
- [ ] Using multi-query attention for faster inference.
- [ ] Supporting full-parameter RLHF training.
## License
The code in this repository is open-sourced under the [Apache-2.0 License](LICENSE).
Please follow the corresponding model licenses when using the model weights:
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
## Citation
If this project is helpful to you, please consider citing it as follows:
```bibtex
@Misc{llama-efficient-tuning,
title = {LLaMA Efficient Tuning},
author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
year = {2023}
}
```
## Acknowledgements
This project is a sibling of [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) and adopts a similar code structure and training approach.
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Efficient-Tuning&type=Date)

data/README.md

@@ -2,16 +2,16 @@ If you are using a custom dataset, please provide your dataset definition in the
```json
"dataset_name": {
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
}
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
}
}
```
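For illustration, a hypothetical entry for a local dataset file (all names are made up):
```json
"my_dataset": {
  "file_name": "my_dataset.json",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "history": "history"
  }
}
```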

data/README_zh.md (new file, 18 lines)

@@ -0,0 +1,18 @@
If you are using a custom dataset, please make sure to provide your dataset definition in the `dataset_info.json` file in the following format.
```json
"数据集名称": {
"hf_hub_url": "HuggingFace上的项目地址若指定则忽略下列三个参数",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选",
"columns": {
"prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None"
}
}
```
The `prompt` and `response` columns should contain non-empty strings. The content of the `query` column will be concatenated with the `prompt` column as the model input. The `history` column should be a list in which each element is a pair of strings representing a user query and a model response, respectively.
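To make the column semantics concrete, here is a hypothetical record using the default column names:
```json
{
  "instruction": "Translate the sentence into French.",
  "input": "The weather is nice today.",
  "output": "Il fait beau aujourd'hui.",
  "history": [
    ["Who are you?", "I am a helpful assistant."]
  ]
}
```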

(Three single-line hash files, whose names are not shown in this view, also changed: one hash was updated from 0a57fbc1d8cb08a8cd71c5eb8425cf59206ffed6 to 57fd080be5bffe4153fe3ee26a175e3d56da30f3, and two files containing 56405bb8f52727e52e99693739494b9b7b0d7ba6 and fa935248a5d40d2bdd5649af99a72a754d40ae7a were removed.)

requirements.txt

@@ -1,10 +1,12 @@
torch>=1.13.1
transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.19.0
peft>=0.3.0
trl>=0.4.7
accelerate>=0.21.0
peft>=0.4.0
trl>=0.5.0
scipy
sentencepiece
tiktoken
jieba
rouge-chinese
nltk

src/api_demo.py

@@ -1,19 +1,13 @@
# coding=utf-8
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
import uvicorn
from llmtuner import ChatModel
from llmtuner.api.app import create_app
from llmtuner.tuner import get_infer_args
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
print("Visit http://localhost:8000/docs for API document.")
if __name__ == "__main__":

src/cli_demo.py

@@ -1,13 +1,8 @@
# coding=utf-8
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args
def main():
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
history = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

src/export_model.py

@@ -1,16 +1,8 @@
# coding=utf-8
# Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from llmtuner.tuner import get_train_args, load_model_and_tokenizer
from llmtuner import export_model
def main():
model_args, _, training_args, finetuning_args, _ = get_train_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
tokenizer.save_pretrained(training_args.output_dir)
print("model and tokenizer have been saved at:", training_args.output_dir)
export_model()
if __name__ == "__main__":

src/llmtuner/__init__.py

@@ -1,4 +1,9 @@
# Level: api, webui > chat > tuner > dsets > extras, hparams
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.1.2"
__version__ = "0.1.6"

src/llmtuner/api/__init__.py

@@ -0,0 +1 @@
from llmtuner.api.app import create_app

src/llmtuner/api/app.py

@@ -5,9 +5,8 @@ from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
from llmtuner.tuner import get_infer_args
from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
@@ -50,8 +49,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
@@ -122,6 +121,6 @@ def create_app(chat_model: ChatModel) -> FastAPI:
if __name__ == "__main__":
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

src/llmtuner/chat/stream_chat.py

@@ -1,37 +1,41 @@
import torch
from types import MethodType
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from transformers import PreTrainedModel, TextIteratorStreamer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import get_template
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner import load_model_and_tokenizer
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
class ChatModel:
def __init__(
self,
model_args: ModelArguments,
data_args: DataArguments,
finetuning_args: FinetuningArguments,
generating_args: GeneratingArguments
) -> None:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.template = get_template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix or ""
self.generating_args = generating_args
self.model = dispatch_model(self.model)
self.model = self.model.eval() # change to eval mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix or self.source_prefix
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
@@ -41,12 +45,14 @@ class ChatModel:
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
input_ids=input_ids,
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
logits_processor=get_logits_processor(),
stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
))
if max_length:
@@ -61,7 +67,11 @@ class ChatModel:
@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
@@ -72,7 +82,11 @@ class ChatModel:
@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
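With this refactor, `ChatModel` resolves its own arguments via `get_infer_args`, so it can be constructed from a plain dict; a minimal usage sketch based on the signatures in this diff (paths are placeholders):
```python
from llmtuner import ChatModel

# placeholder paths; the dict is forwarded to get_infer_args()
chat_model = ChatModel(dict(
    model_name_or_path="path_to_your_model",
    template="default",
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint",
))

# stream_chat() yields decoded text chunks from the TextIteratorStreamer
for new_text in chat_model.stream_chat("Hello!", history=[]):
    print(new_text, end="", flush=True)
```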

src/llmtuner/dsets/loader.py

@@ -1,40 +1,50 @@
import os
import hashlib
from typing import List
from typing import TYPE_CHECKING, List, Optional
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, DataArguments
if TYPE_CHECKING:
from datasets import Dataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def get_dataset(
model_args: ModelArguments,
data_args: DataArguments
) -> Dataset:
def checksum(file_path, hash):
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
ext2type = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
model_args: "ModelArguments",
data_args: "DataArguments"
) -> "Dataset":
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
all_datasets: List["Dataset"] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
@@ -47,60 +57,62 @@ def get_dataset(
data_path = None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = ext2type.get(data_files[0].split(".")[-1], None)
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = ext2type.get(data_files[0].split(".")[-1], None)
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
checksum(data_files[0], dataset_attr.dataset_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
raw_datasets = load_dataset(
dataset = load_dataset(
data_path,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
)
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
prefix_data = [dataset_attr.source_prefix] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
]: # every dataset will have 4 columns same as each other
if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
dataset = dataset.add_column("prefix", prefix_data)
for column_name in ["prompt", "query", "response", "history"]: # align datasets
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.source_prefix: # add prefix
if data_args.streaming:
features = dataset.features
features["prefix"] = Value(dtype="string", id=None)
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
else:
prefix_data = [dataset_attr.source_prefix] * len(dataset)
dataset = dataset.add_column("prefix", prefix_data)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
else:
raise ValueError("Unknown mixing strategy.")

View File

@@ -1,91 +1,88 @@
import tiktoken
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
from itertools import chain
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
def preprocess_dataset(
dataset: "Dataset",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> "Dataset":
column_names = list(dataset.column_names)
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
# support question with a single answer or multiple answers
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
yield dialog
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
prefix = examples["prefix"][i] if "prefix" in examples else None
yield query, response, history, prefix
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.max_source_length
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of max_source_length
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
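Note: a toy illustration (not repo code) of the grouping above with block_size=4; the remainder that does not fill a block is dropped:

from itertools import chain

tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
concatenated = {k: list(chain(*v)) for k, v in tokenized.items()}
block_size = 4
total_length = (len(concatenated["input_ids"]) // block_size) * block_size
blocks = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated.items()}
print(blocks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]} -- the trailing 9 is dropped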
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], []
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
target_ids = target_ids[:data_args.max_target_length]
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
break
input_ids += source_ids + target_ids
labels += [IGNORE_INDEX] * len(source_ids) + target_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
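Note: a toy sketch (assumed token ids, not repo code) of the prompt masking above; loss is computed only where labels differ from IGNORE_INDEX:

IGNORE_INDEX = -100
source_ids = [101, 102, 103]  # encoded prompt of one turn
target_ids = [201, 202]       # encoded response (eos already appended by the template)
input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids
print(input_ids)  # [101, 102, 103, 201, 202]
print(labels)     # [-100, -100, -100, 201, 202]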
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `<bos> Y`
model_inputs = {"input_ids": [], "labels": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, prefix in construct_example(examples):
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@@ -93,80 +90,82 @@ def preprocess_dataset(
target_ids = target_ids[:data_args.max_target_length]
model_inputs["input_ids"].append(source_ids)
model_inputs["attention_mask"].append([1] * len(source_ids))
model_inputs["labels"].append(target_ids)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, prefix in construct_example(examples):
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
if len(prompt_ids) > data_args.max_source_length:
prompt_ids = prompt_ids[:data_args.max_source_length]
if len(chosen_ids) > data_args.max_target_length:
chosen_ids = chosen_ids[:data_args.max_target_length]
if len(rejected_ids) > data_args.max_target_length:
rejected_ids = rejected_ids[:data_args.max_target_length]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
skip_special_tokens=False)
))
print("labels:\n{}".format(tokenizer.decode([
token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
], skip_special_tokens=False)))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
elif stage == "sft":
preprocess_function = preprocess_unsupervised_dataset \
if training_args.predict_with_generate else preprocess_supervised_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_function,
batched=True,
remove_columns=column_names,
**kwargs
)
if stage == "pt":
print_unsupervised_dataset_example(dataset[0])
elif stage == "sft":
print_supervised_dataset_example(dataset[0])
elif stage == "rm":
print_pairwise_dataset_example(dataset[0])
elif stage == "ppo":
print_unsupervised_dataset_example(dataset[0])
print_function(next(iter(dataset)))
return dataset

View File

@@ -1,16 +1,30 @@
from typing import TYPE_CHECKING, Dict, Union
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -1,74 +1,128 @@
import os
import json
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
logger = get_logger(__name__)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
def timing(self):
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a substep during gradient accumulation.
"""
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
r"""
Event called after a successful prediction.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if not state.is_local_process_zero:
return
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()
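Note: a self-contained toy reproduction (assumed numbers, not repo code) of the timing() estimate the callback reports:

from datetime import timedelta

elapsed, cur_steps, max_steps = 125.0, 50, 200
avg_time_per_step = elapsed / cur_steps if cur_steps != 0 else 0
print(str(timedelta(seconds=int(elapsed))))                                       # 0:02:05 elapsed
print(str(timedelta(seconds=int((max_steps - cur_steps) * avg_time_per_step))))   # 0:06:15 remaining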

View File

@@ -1,10 +1,12 @@
IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"]
@@ -25,15 +27,19 @@ SUPPORTED_MODELS = {
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
"InternLM-7B": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"XVERSE-13B": "xverse/XVERSE-13B",
"ChatGLM2-6B": "THUDM/chatglm2-6b"
}
DEFAULT_MODULE = {
@@ -43,5 +49,8 @@ DEFAULT_MODULE = {
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"InternLM": "q_proj,v_proj"
"InternLM": "q_proj,v_proj",
"Qwen": "c_attn",
"XVERSE": "q_proj,v_proj",
"ChatGLM2": "query_key_value"
}

View File

@@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"

View File

@@ -1,12 +1,17 @@
import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import (
LogitsProcessor,
LogitsProcessorList,
StoppingCriteria,
StoppingCriteriaList
)
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
class AverageMeter:
r"""
@@ -28,7 +33,6 @@ class AverageMeter:
self.avg = self.sum / self.count
# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -44,29 +48,53 @@ def get_logits_processor() -> LogitsProcessorList:
return logits_processor
class StopWordsCriteria(StoppingCriteria):
def __init__(self, stop_ids: List[int]) -> None:
super().__init__()
self.stop_ids = stop_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopWordsCriteria(stop_ids))
return stopping_criteria
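Note: a small usage sketch (a tiny public checkpoint stands in for a tuned model; not part of this diff) of the new stopping criteria:

from transformers import AutoModelForCausalLM, AutoTokenizer
from llmtuner.extras.misc import get_stopping_criteria

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tok("Hello", return_tensors="pt")
# Generation halts as soon as the last generated token is one of the stop ids.
out = model.generate(**inputs, max_new_tokens=8, stopping_criteria=get_stopping_criteria([tok.eos_token_id]))
print(tok.decode(out[0]))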
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_params = num_params * 2
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
return trainable_params, all_param
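Note: a quick self-contained check (toy module, not repo code) of count_parameters; the first layer is frozen:

import torch
from llmtuner.extras.misc import count_parameters

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
net[0].requires_grad_(False)
trainable_params, all_param = count_parameters(net)
print(trainable_params, all_param)  # 10 30 (the frozen 4x4 layer holds 20 of the 30 params)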
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_type: str,
output_embedding_layer_name: Optional[str] = "lm_head",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
@@ -83,19 +111,23 @@ def prepare_model_for_training(
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype
if finetuning_type != "full" and hasattr(model, output_layer_name):
if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
return model
def torch_gc() -> None:
r"""
Collects GPU memory.
@@ -103,3 +135,25 @@ def torch_gc() -> None:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()

View File

@@ -12,8 +12,8 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():

View File

@@ -1,64 +1,213 @@
import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
logger = get_logger(__name__)
@dataclass
class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
return "".join(self._format_example(query, history, prefix))
prefix, history = self._format(query, resp, history, prefix)
encoded_pairs = self._encode(tokenizer, prefix, history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
return prompt_ids, encoded_pairs[-1][1]
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
prefix, history = self._format(query, resp, history, prefix)
encoded_pairs = self._encode(tokenizer, prefix, history)
return encoded_pairs
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None
) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
prefix = [prefix] if prefix else self.prefix # use prefix if provided
history = history if (history and self.use_history) else []
history = history + [(query, "<dummy>")]
convs = []
for turn_idx, (user_query, bot_resp) in enumerate(history):
history = history + [(query, resp)]
return prefix, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if tokenizer.bos_token_id:
bos_ids = [tokenizer.bos_token_id]
else:
bos_ids = [] # bos token is optional
if tokenizer.eos_token_id:
eos_ids = [tokenizer.eos_token_id]
else:
raise ValueError("EOS token is required.")
return bos_ids, eos_ids
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
prefix: List[Union[str, Dict[str, str]]],
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
if prefix: # has prefix
prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
else:
prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
def _convert_inputs_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
query: Optional[str] = "",
idx: Optional[str] = ""
) -> List[int]:
r"""
Converts context to token ids.
"""
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
token_ids = []
for elem in context:
if isinstance(elem, str):
elem = elem.replace("{{query}}", query, 1)
elem = elem.replace("{{idx}}", idx, 1)
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise NotImplementedError
return token_ids
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
prefix: List[Union[str, Dict[str, str]]],
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # the llama2 template has no sep_ids
query = prefix[0] + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
templates: Dict[str, Template] = {}
def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
sep: List[Union[str, Dict[str, str]]],
stop_words: List[str],
use_history: bool
) -> None:
template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
sep=sep,
stop_words=stop_words,
use_history=use_history
)
def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
if len(template.stop_words): # inplace method
tokenizer.eos_token = template.stop_words[0]
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
return template
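Note: a hypothetical usage sketch of the new API (the checkpoint name is a placeholder for any model matching the template):

from transformers import AutoTokenizer
from llmtuner.extras.template import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
template = get_template_and_fix_tokenizer("llama2", tokenizer)
# encode_oneturn returns (prompt_ids, response_ids); encode_multiturn returns one pair per turn.
prompt_ids, response_ids = template.encode_oneturn(tokenizer, query="Hi!", resp="Hello!")
print(len(prompt_ids), len(response_ids))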
@@ -67,9 +216,12 @@ Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
prefix=[],
prompt=[
"{{query}}"
],
sep=[],
stop_words=[],
use_history=False
)
@@ -79,10 +231,17 @@ Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
prefix=[
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
],
prompt=[
"Human: {{query}}\nAssistant: "
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@@ -94,16 +253,38 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
"""
register_template(
name="llama2",
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt=" [INST] {query} [/INST] ",
sep="</s>",
prefix=[
"<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
sep=[],
stop_words=[],
use_history=True
)
r"""
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
"""
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
sep=[],
stop_words=[],
use_history=True
)
@@ -114,10 +295,17 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
prefix=[
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
sep=[
"\n\n"
],
stop_words=[],
use_history=True
)
@@ -128,10 +316,15 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
prefix=[
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
],
prompt=[
"USER: {{query}} ASSISTANT: "
],
sep=[],
stop_words=[],
use_history=True
)
@@ -141,9 +334,14 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
prefix=[],
prompt=[
"Human: {{query}}\n\nBelle: "
],
sep=[
"\n\n"
],
stop_words=[],
use_history=True
)
@@ -153,9 +351,14 @@ Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
prefix=[],
prompt=[
"User: {{query}}\nBot: "
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@@ -165,9 +368,14 @@ Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
prefix=[],
prompt=[
"Human: {{query}}\nAssistant: "
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@@ -177,9 +385,17 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
prefix=[],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
sep=[
"\n"
],
stop_words=[],
use_history=True
)
@@ -189,10 +405,17 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
prefix=[
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
],
prompt=[
"Human: {{query}}###Assistant: "
],
sep=[
"###"
],
stop_words=[],
use_history=True
)
@@ -202,9 +425,18 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
prefix=[],
prompt=[
"<|User|>:{{query}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
],
sep=[
"\n"
],
stop_words=[
"<eoa>"
],
use_history=True
)
@@ -214,9 +446,17 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="</s>",
prefix=[
{"token": "<reserved_102>"} # user token (a little difference in the first turn)
],
prompt=[
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
sep=[],
stop_words=[
"<reserved_102>" # user token
],
use_history=True
)
@@ -227,8 +467,71 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
prefix=[
{"token": "<|system|>"},
"\n",
{"token": "<|end|>"}
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
],
sep=[
"\n"
],
stop_words=[
"<|end|>"
],
use_history=True
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
"""
register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
"system\nYou are a helpful assistant.",
{"token": "<|im_end|>"}
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
sep=[
"\n"
],
stop_words=[
"<|im_end|>"
],
use_history=True
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"}
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
sep=[
"\n\n"
],
stop_words=[],
use_history=True
)
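Note: registering a custom template follows the same pattern; a hypothetical example (the name and strings are invented):

register_template(
    name="my_chat",
    prefix=[
        "You are a concise assistant."
    ],
    prompt=[
        "### User:\n{{query}}\n### Assistant:\n"
    ],
    sep=[
        "\n"
    ],
    stop_words=[],
    use_history=True
)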

View File

@@ -1,6 +1,6 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import dataclass, field
@@ -16,19 +16,22 @@ class DatasetAttr:
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass
class DataArguments:
"""
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: str = field(
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default="alpaca_zh",
default="alpaca_en",
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
@@ -39,6 +42,22 @@ class DataArguments:
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing."}
)
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -71,13 +90,9 @@ class DataArguments:
default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
)
val_size: Optional[float] = field(
default=0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
def init_for_training(self): # support mixing multiple datasets
@@ -92,6 +107,9 @@ class DataArguments:
else:
prefix_list = [None] * len(dataset_names)
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = []
for i, name in enumerate(dataset_names):
if name not in dataset_info:
@@ -111,9 +129,9 @@ class DataArguments:
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
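Note: a hypothetical dataset_info.json entry (all names are placeholders) matching the renamed attributes; the "columns" keys now map directly onto DatasetAttr.prompt/query/response/history:

import json

entry = {
    "my_dataset": {
        "file_name": "my_dataset.json",
        "columns": {"prompt": "instruction", "query": "input", "response": "output", "history": "history"}
    }
}
print(json.dumps(entry, indent=2))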

View File

@@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
"""
r"""
Arguments pertaining to which fine-tuning techniques we are going to use.
"""
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
@@ -14,23 +14,27 @@ class FinetuningArguments:
)
num_hidden_layers: Optional[int] = field(
default=32,
metadata={"help": "Number of decoder blocks in the model. \
metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
Falcon choices: [\"32\", \"60\"], \
Baichuan choices: [\"32\", \"40\"]"}
Baichuan choices: [\"32\", \"40\"] \
Qwen choices: [\"32\"], \
XVERSE choices: [\"40\"]"}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
Baichuan choices: [\"mlp\", \"self_attn\"], \
Qwen choices: [\"mlp\", \"attn\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
lora_rank: Optional[int] = field(
default=8,
@@ -47,9 +51,19 @@ class FinetuningArguments:
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
def __post_init__(self):
@@ -66,14 +80,14 @@ class FinetuningArguments:
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
"""Saves the content of this instance in JSON format inside `json_path`."""
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Creates an instance from the content of `json_path`."""
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
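Note: a quick hypothetical round-trip (field values are illustrative) through the JSON helpers above:

from llmtuner.hparams import FinetuningArguments

args = FinetuningArguments(finetuning_type="lora", lora_rank=8, dpo_beta=0.1)
args.save_to_json("finetuning_args.json")
print(FinetuningArguments.load_from_json("finetuning_args.json").dpo_beta)  # 0.1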

View File

@@ -4,10 +4,10 @@ from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
r"""
Arguments pertaining to which stage we are going to perform.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)

View File

@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
@dataclass
class GeneratingArguments:
"""
r"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
@dataclass
class ModelArguments:
"""
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: str = field(
@@ -55,14 +55,14 @@ class ModelArguments:
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
hf_auth_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
def __post_init__(self):
if self.checkpoint_dir is not None: # support merging multiple lora weights
@@ -70,3 +70,7 @@ class ModelArguments:
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
if self.use_auth_token == True and self.hf_auth_token is not None:
from huggingface_hub.hf_api import HfFolder # lazy load
HfFolder.save_token(self.hf_auth_token)

View File

@@ -1,5 +1 @@
from llmtuner.tuner.tune import export_model, run_exp

View File

@@ -1,7 +1,7 @@
import os
import torch
from typing import TYPE_CHECKING
from peft import (
PeftModel,
TaskType,
@@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -36,7 +39,7 @@ def init_adapter(
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full":
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
@@ -62,7 +65,7 @@ def init_adapter(
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir

View File

@@ -1,43 +1,48 @@
import os
import torch
from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.deepspeed import is_deepspeed_zero3_enabled
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
@@ -47,9 +52,6 @@ def load_model_and_tokenizer(
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
"RM and PPO training can only be performed with the LoRA method."
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
@@ -63,10 +65,13 @@ def load_model_and_tokenizer(
padding_side=model_args.padding_side,
**config_kwargs
)
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
tokenizer.pad_token_id = 0 # set as the <unk> token
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
is_mergeable = True
# Quantization configurations (using bitsandbytes library).
@@ -74,16 +79,10 @@ def load_model_and_tokenizer(
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
@@ -96,14 +95,6 @@ def load_model_and_tokenizer(
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if not is_trainable: # `device_map=auto` should be used for inference only
config_kwargs["device_map"] = "auto"
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
# Load and prepare pretrained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
@@ -126,7 +117,8 @@ def load_model_and_tokenizer(
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
@@ -137,8 +129,6 @@ def load_model_and_tokenizer(
})
if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
@@ -147,6 +137,9 @@ def load_model_and_tokenizer(
model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
print_trainable_params(model)
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
return model, tokenizer
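For context, a minimal sketch of how this loader might be invoked from a training workflow; the argument values below are hypothetical, and only `finetuning_type`/`stage` mirror the signature shown above:

from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer

model_args = ModelArguments(model_name_or_path="path/to/base-model")  # hypothetical path
finetuning_args = FinetuningArguments(finetuning_type="lora")
model, tokenizer = load_model_and_tokenizer(
    model_args, finetuning_args, is_trainable=True, stage="sft"
)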

View File

@@ -19,20 +19,66 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
))
return _parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
if args is not None:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@@ -48,24 +94,42 @@ def get_train_args(
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
assert (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stage can only be performed at training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if general_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
raise ValueError("Streaming mode should have an integer val size.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -73,13 +137,18 @@ def get_train_args(
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training.")
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:
@@ -91,44 +160,36 @@ def get_train_args(
model_args.compute_dtype = torch.float32
# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), training_args.fp16
))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, general_args
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
if args is not None:
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint.")
return model_args, data_args, finetuning_args, generating_args
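The refactor above funnels dict, YAML, JSON, and CLI inputs through the single `_parse_args` helper. A hedged sketch of the dict entry point (the key names and values are assumptions based on the dataclasses parsed above):

args = {
    "stage": "sft",                              # hypothetical values throughout
    "model_name_or_path": "path/to/base-model",
    "dataset": "alpaca_gpt4_en",
    "output_dir": "outputs/sft",
    "do_train": True,
}
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)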

View File

@@ -1,34 +1,37 @@
import os
import torch
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from trl import PreTrainedModelWrapper
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer):
class PeftModelMixin:
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
"""
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self._remove_log()
def _remove_log(self):
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
def __init__(self) -> None: # for type checking
self.model: PreTrainedModel = None
self.tokenizer: "PreTrainedTokenizer" = None
self.args: "Seq2SeqTrainingArguments" = None
self.finetuning_args: "FinetuningArguments" = None
self.state: "TrainerState" = None
raise AssertionError("Mixin should not be initialized.")
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
@@ -41,29 +44,36 @@ class PeftTrainer(Seq2SeqTrainer):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()
v_head_state_dict = {
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
for name in model_state_dict.keys() if name.startswith("v_head.")
}
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
backbone_model = getattr(model, "pretrained_model")
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
model = model.pretrained_model
state_dict = state_dict or get_state_dict(model)
if isinstance(model, (PeftModel, PreTrainedModel)):
model.config.use_cache = True
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
model.config.use_cache = False
else:
backbone_model = model
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if self.finetuning_args.finetuning_type == "lora":
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
state_dict=get_state_dict(backbone_model),
safe_serialization=self.args.save_safetensors
)
backbone_model.config.use_cache = False
if self.tokenizer is not None:
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
try:
self.tokenizer.save_pretrained(output_dir)
except Exception:
logger.warning("Cannot save tokenizer, copy the files manually.")
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
@@ -73,16 +83,25 @@ class PeftTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if self.finetuning_args.finetuning_type == "lora":
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if isinstance(model, PreTrainedModelWrapper):
model.v_head.load_state_dict(torch.load(
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
))
model = model.pretrained_model
if isinstance(model, PeftModel):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(backbone_model, self.state.best_model_checkpoint)
load_trainable_params(model, self.state.best_model_checkpoint)
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
Seq2SeqTrainer.__init__(self, **kwargs)
self.finetuning_args = finetuning_args
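The mixin split works because Python's method resolution order prefers `PeftModelMixin` when it is listed first, so its `_save` and `_load_best_model` shadow the `Seq2SeqTrainer` versions. A toy illustration of that mechanism:

class Base:
    def _save(self):
        return "base save"

class Mixin:
    def _save(self):
        return "patched save"

class Patched(Mixin, Base):  # mixin listed first, as in PeftTrainer(PeftModelMixin, Seq2SeqTrainer)
    pass

assert Patched()._save() == "patched save"  # Mixin._save wins in the MRO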

View File

@@ -0,0 +1 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@@ -0,0 +1,51 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
start, end = feature.size(0) - answer_len, feature.size(0)
else:
start, end = prompt_len, prompt_len + answer_len # labels must cover the whole answer span
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
label_positions = []
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append({
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len)
})
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(
concatenated_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
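A hedged usage sketch of the collator above with toy token ids; the tokenizer choice is an assumption, and any tokenizer with a pad token would do:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical choice
tokenizer.pad_token = tokenizer.eos_token          # gpt2 has no pad token by default
collator = DPODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
features = [{"prompt_ids": [1, 2], "chosen_ids": [3, 4, 5], "rejected_ids": [6]}]
batch = collator(features)
# batch["input_ids"] has 2 * len(features) rows: all chosen sequences first, then all rejected ones.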

View File

@@ -0,0 +1,75 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import Trainer
from trl import DPOTrainer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
**kwargs
):
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = finetuning_args.dpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
if ref_model is not None:
if hasattr(self, "accelerator"):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
else:
raise AttributeError("Please update `transformers`.")
def concatenated_forward(
self,
model: Optional[torch.nn.Module] = None,
batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_disable()
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
with unwrapped_model.disable_adapter():
all_logits: torch.Tensor = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
return_dict=True
).logits.to(torch.float32)
else:
all_logits: torch.Tensor = model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
return_dict=True
).logits.to(torch.float32)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_enable()
all_logps = self._get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
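For reference, the DPO objective (arXiv 2305.18290) that trl's `DPOTrainer` derives from these chosen/rejected log-probabilities, in sketch form under that paper's formulation:

import torch
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # -log sigmoid(beta * [(log pi_c - log ref_c) - (log pi_r - log ref_r)])
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    return -F.logsigmoid(beta * logits).mean()

loss = dpo_loss_sketch(torch.tensor([-1.0]), torch.tensor([-2.0]),
                       torch.tensor([-1.5]), torch.tensor([-1.5]))  # hypothetical log-probs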

View File

@@ -0,0 +1,59 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
training_args.remove_unused_columns = False # important for pairwise dataset
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
generating_args=generating_args,
ref_model=ref_model,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

View File

@@ -2,21 +2,24 @@ import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
from transformers.modeling_utils import PreTrainedModel
from transformers import TrainerState, TrainerControl
from trl import PPOTrainer
from trl.core import LengthSampler
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, get_logits_processor
from llmtuner.hparams import FinetuningArguments
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -25,21 +28,22 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: List[LogCallback],
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["LogCallback"],
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.log_callback = callbacks[0]
self.state = TrainerState()
self.control = TrainerControl()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
self._remove_log()
def ppo_train(self, max_target_length: int) -> None:
r"""
@@ -66,19 +70,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate`
gen_kwargs = {
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
gen_kwargs = self.generating_args.to_dict()
gen_kwargs["logits_processor"] = get_logits_processor()
gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
@@ -86,51 +86,38 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
batch = next(dataiter)
steps_trained += 1
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get responses
query_tensors = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
rewards = self.get_rewards(queries, responses, unwrapped_model)
queries, responses = [], []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(
**self.prepare_model_inputs(queries, responses),
output_hidden_states=True,
return_dict=True
)
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default")
# Run PPO step
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
# Run PPO step
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
)
print(logs)
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
@@ -147,38 +134,57 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
@torch.no_grad()
def generate(
def get_inputs(
self,
inputs: Dict[str, torch.Tensor],
batch: Dict[str, torch.Tensor],
length_sampler: Optional[Callable] = None,
return_prompt: Optional[bool] = True,
**generation_kwargs
) -> torch.Tensor:
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates the model's responses given queries.
Subclass and override to inject custom behavior.
"""
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
unwrapped_model = self.accelerator.unwrap_model(self.model)
response = unwrapped_model.generate(**inputs, **generation_kwargs)
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
queries, responses = [], []
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query[i, query_length:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
return queries, responses
@torch.no_grad()
def get_rewards(
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead"
) -> List[torch.Tensor]:
r"""
Computes scores using the given reward model.
"""
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards
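# Hedged illustration of the reward readout in get_rewards above: the value head
# scores every token position, and the last position's value is taken as the
# scalar reward. (Standalone sketch; the shapes below are hypothetical.)
import torch
values = torch.randn(2, 7)                                   # (batch, seq_len) per-token values
rewards = [r for r in values[:, -1].float().detach().cpu()]  # one fp32 scalar tensor per sample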
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""

View File

@@ -1,11 +1,13 @@
import torch
from typing import Dict, List, Literal, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
valuehead_state_dict = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
@@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead,
model: "AutoModelForCausalLMWithValueHead",
layer_norm_names: List[str] = LAYERNORM_NAMES,
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
layer_norm_state_dict = {}

View File

@@ -1,27 +1,29 @@
# Inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from trl import PPOConfig
from torch.optim import AdamW
from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
@@ -38,20 +40,23 @@ def run_ppo(
max_grad_norm=training_args.max_grad_norm
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
total_train_batch_size = \
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps
)
# Initialize our Trainer
ppo_trainer = PPOPeftTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks,
config=ppo_config,
model=model,
@@ -63,8 +68,10 @@ def run_ppo(
lr_scheduler=lr_scheduler
)
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])
# Training
if training_args.do_train:
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model so that the output folder exists
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])
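A worked example of the scheduler step arithmetic introduced above (all values hypothetical):

import math

per_device_train_batch_size, gradient_accumulation_steps, world_size = 4, 4, 2
num_train_epochs, len_dataset = 3, 10000

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size  # 32
num_training_steps = num_train_epochs * math.ceil(len_dataset / total_train_batch_size)          # 3 * 313 = 939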

View File

@@ -1,32 +1,30 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
import math
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
trainer = PeftTrainer(
@@ -36,7 +34,7 @@ def run_pt(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Training

View File

@@ -1,8 +1,10 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
@@ -15,5 +17,11 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
features = [
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
}
for key in ("chosen_ids", "rejected_ids") for feature in features
]
return super().__call__(features)
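A toy sketch of the feature expansion the new `__call__` performs before delegating to the base collator:

features = [{"prompt_ids": [1, 2], "chosen_ids": [3, 4], "rejected_ids": [5]}]
expanded = [
    {
        "input_ids": f["prompt_ids"] + f[key],
        "attention_mask": [1] * (len(f["prompt_ids"]) + len(f[key])),
    }
    for key in ("chosen_ids", "rejected_ids") for f in features
]
assert expanded[0]["input_ids"] == [1, 2, 3, 4]  # chosen example first
assert expanded[1]["input_ids"] == [1, 2, 5]     # rejected example last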

View File

@@ -1,9 +1,18 @@
import os
import json
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
logger = get_logger(__name__)
class PairwisePeftTrainer(PeftTrainer):
r"""
@@ -16,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
def compute_loss(
self,
model: PreTrainedModel,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -36,3 +45,26 @@ class PairwisePeftTrainer(PeftTrainer):
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
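# Quick numeric check of the pairwise loss computed above, as a standalone sketch:
# sigmoid(2.0) ~= 0.8808, so -log(0.8808) ~= 0.1269; larger margins push the loss toward 0.
import torch
r_accept, r_reject = torch.tensor([2.0]), torch.tensor([0.0])
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()  # ~0.1269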
def save_predictions(
self,
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
acc_scores, rej_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for acc_score, rej_score in zip(acc_scores, rej_scores):
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
writer.write("\n".join(res))

View File

@@ -2,25 +2,26 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@@ -38,7 +39,7 @@ def run_rm(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Training
@@ -56,3 +57,10 @@ def run_rm(
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)

View File

@@ -1,7 +1,6 @@
import numpy as np
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple, Union
from transformers.tokenization_utils import PreTrainedTokenizer
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
@dataclass
class ComputeMetrics:
@@ -16,7 +18,7 @@ class ComputeMetrics:
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
"""
tokenizer: PreTrainedTokenizer
tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r"""

View File

@@ -3,13 +3,15 @@ import json
import torch
import numpy as np
import torch.nn as nn
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
@@ -77,11 +79,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor
return padded_tensor.contiguous() # in contiguous memory
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.

View File

@@ -1,25 +1,28 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
@@ -44,17 +47,13 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Keyword arguments for `model.generate`
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_new_tokens": data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
gen_kwargs = generating_args.to_dict()
gen_kwargs["logits_processor"] = get_logits_processor()
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
# Training
if training_args.do_train:

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks
if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
try:
tokenizer.save_pretrained(training_args.output_dir)
except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
run_exp()
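A hedged sketch of launching a run programmatically through this dispatcher; the dict keys and values below are hypothetical:

run_exp({
    "stage": "dpo",
    "model_name_or_path": "path/to/base-model",
    "dataset": "comparison_gpt4_en",
    "finetuning_type": "lora",
    "output_dir": "outputs/dpo",
    "do_train": True,
})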

View File

@@ -0,0 +1 @@
from llmtuner.webui.interface import create_ui, create_web_demo

View File

@@ -1,22 +1,22 @@
import os
from typing import List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.tuner import get_infer_args
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self, *args):
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if len(args) != 0:
super().__init__(*args)
def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
if lazy_init:
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
else:
super().__init__(args)
def load_model(
self,
@@ -54,10 +54,10 @@ class WebChatModel(ChatModel):
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
template=template,
source_prefix=source_prefix
)
super().__init__(*get_infer_args(args))
super().__init__(args)
yield ALERTS["info_loaded"][lang]
@@ -84,6 +84,14 @@ class WebChatModel(ChatModel):
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
response = self.postprocess(response)
new_history = history + [(query, response)]
chatbot[-1] = [query, response]
yield chatbot, new_history
def postprocess(self, response: str) -> str:
blocks = response.split("```")
for i, block in enumerate(blocks):
if i % 2 == 0:
blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
return "```".join(blocks)

View File

@@ -1,4 +1,6 @@
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab
from llmtuner.webui.components.chatbot import create_chat_box

View File

@@ -1,16 +1,17 @@
from typing import Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel,
chat_model: "WebChatModel",
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
@@ -22,13 +23,9 @@ def create_chat_box(
with gr.Column(scale=1):
clear_btn = gr.Button()
max_new_tokens = gr.Slider(
10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
temperature = gr.Slider(
0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
)
max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
history = gr.State([])

View File

@@ -1,10 +1,12 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)
@@ -14,6 +16,6 @@ def create_preview_box() -> Tuple[Block, Component, Component, Component]:
close_btn = gr.Button()
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return preview_box, preview_count, preview_samples, close_btn

View File

@@ -1,14 +1,16 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -18,7 +20,12 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -31,6 +38,9 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
@@ -52,7 +62,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
batch_size,
predict
],
[output_box]
[
output_box,
process_bar
]
)
stop_btn.click(runner.set_abort, queue=False)
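The new `process_bar` is an ordinary hidden `gr.Slider` repurposed as a progress indicator: the start handler now targets `[output_box, process_bar]`, and the runner (see its hunk below) yields a `(log_text, gr.update(...))` pair on every poll. A self-contained sketch of the pattern, with a dummy job standing in for evaluation:

```python
import time
import gradio as gr

def fake_job():
    # Generator callback: each yield refreshes the log and the slider.
    for step in range(1, 101):
        time.sleep(0.05)
        yield "Running step {}/100".format(step), gr.update(value=step, visible=True)
    yield "Done.", gr.update(visible=False)

with gr.Blocks() as demo:
    start_btn = gr.Button("Start")
    process_bar = gr.Slider(visible=False, interactive=False)
    output_box = gr.Markdown()
    start_btn.click(fake_job, outputs=[output_box, process_bar])

demo.queue()  # generator callbacks stream only through the queue
```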

View File

@@ -0,0 +1,37 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from llmtuner.webui.utils import save_model
if TYPE_CHECKING:
from gradio.components import Component
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
export_btn.click(
save_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["template"],
max_shard_size,
save_dir
],
[info_box]
)
return dict(
save_dir=save_dir,
max_shard_size=max_shard_size,
export_btn=export_btn,
info_box=info_box
)

View File

@@ -1,13 +1,15 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()

View File

@@ -1,16 +1,18 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -20,7 +22,12 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -36,7 +43,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
max_grad_norm = gr.Textbox(value="1.0")
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
@@ -44,20 +51,26 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
padding_side = gr.Radio(choices=["left", "right"], value="left")
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
resume_lora_training = gr.Checkbox(value=True, scale=1)
with gr.Row():
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
with gr.Column(scale=4):
output_dir = gr.Textbox(interactive=True)
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
with gr.Row():
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
@@ -86,21 +99,26 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
gradient_accumulation_steps,
lr_scheduler_type,
max_grad_norm,
dev_ratio,
val_size,
logging_steps,
save_steps,
warmup_steps,
compute_type,
padding_side,
lora_rank,
lora_dropout,
lora_target,
resume_lora_training,
output_dir
],
[output_box]
[
output_box,
process_bar
]
)
stop_btn.click(runner.set_abort, queue=False)
output_box.change(
process_bar.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
@@ -120,16 +138,18 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
dev_ratio=dev_ratio,
val_size=val_size,
advanced_tab=advanced_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
padding_side=padding_side,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
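The `dev_ratio` slider is renamed `val_size` here and everywhere downstream; when it is non-zero, the runner (see its hunk below) converts it into a periodic-evaluation configuration. Distilled, with placeholder values:

```python
args = dict(stage="sft", do_train=True)  # plus the remaining hyperparameters

val_size = 0.05
if val_size > 1e-6:
    args["val_size"] = val_size
    args["evaluation_strategy"] = "steps"
    args["eval_steps"] = 100  # the webui passes save_steps here
    args["load_best_model_at_end"] = True
```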

View File

@@ -1,15 +1,17 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, Component]:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
@@ -41,7 +43,7 @@ def create_top() -> Dict[str, Component]:
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
return dict(
lang=lang,

View File

@@ -5,8 +5,11 @@ from llmtuner.webui.components import (
create_top,
create_sft_tab,
create_eval_tab,
create_infer_tab
create_infer_tab,
create_export_tab,
create_chat_box
)
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
@@ -30,7 +33,10 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Chat"):
infer_elems = create_infer_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)
demo.load(
@@ -43,12 +49,30 @@ def create_ui() -> gr.Blocks:
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
queue=False
)
return demo
def create_web_demo() -> gr.Blocks:
chat_model = WebChatModel(lazy_init=False)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
return demo
if __name__ == "__main__":
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
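With this commit the package root re-exports the high-level entry points (`run_exp`, `create_ui`, `create_web_demo`, as the script hunks below show), so launching the board programmatically reduces to:

```python
from llmtuner import create_ui

demo = create_ui()
demo.queue()  # required: run_train/run_eval are generator callbacks
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
```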

View File

@@ -227,9 +227,9 @@ LOCALES = {
"info": "用于梯度裁剪的范数。"
}
},
"dev_ratio": {
"val_size": {
"en": {
"label": "Dev ratio",
"label": "Val size",
"info": "Proportion of data in the dev set."
},
"zh": {
@@ -277,6 +277,16 @@ LOCALES = {
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
},
"padding_side": {
"en": {
"label": "Padding side",
"info": "The side on which the model should have padding applied."
},
"zh": {
"label": "填充位置",
"info": "使用左填充或右填充。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
@@ -315,6 +325,16 @@ LOCALES = {
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
}
},
"resume_lora_training": {
"en": {
"label": "Resume LoRA training",
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
},
"zh": {
"label": "继续上次的训练",
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
}
},
"start_btn": {
"en": {
"value": "Start"
@@ -452,6 +472,34 @@ LOCALES = {
"zh": {
"label": "温度系数"
}
},
"save_dir": {
"en": {
"label": "Export dir",
"info": "Directory to save exported model."
},
"zh": {
"label": "导出目录",
"info": "保存导出模型的文件夹路径。"
}
},
"max_shard_size": {
"en": {
"label": "Max shard size (GB)",
"info": "The maximum size for a model file."
},
"zh": {
"label": "最大分块大小GB",
"info": "模型文件的最大大小。"
}
},
"export_btn": {
"en": {
"value": "Export"
},
"zh": {
"value": "开始导出"
}
}
}
@@ -477,6 +525,18 @@ ALERTS = {
"en": "Please choose a dataset.",
"zh": "请选择数据集。"
},
"err_no_checkpoint": {
"en": "Please select a checkpoint.",
"zh": "请选择断点。"
},
"err_no_save_dir": {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"err_failed": {
"en": "Failed.",
"zh": "训练出错。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"
@@ -504,5 +564,13 @@ ALERTS = {
"info_unloaded": {
"en": "Model unloaded.",
"zh": "模型已卸载。"
},
"info_exporting": {
"en": "Exporting model...",
"zh": "正在导出模型……"
},
"info_exported": {
"en": "Model exported.",
"zh": "模型导出完成。"
}
}
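`LOCALES` and `ALERTS` are plain nested dicts keyed by element name and language code, so reading the new export strings is just indexing:

```python
from llmtuner.webui.locales import LOCALES, ALERTS

lang = "en"
print(LOCALES["max_shard_size"][lang]["label"])  # -> "Max shard size (GB)"
print(ALERTS["err_no_save_dir"][lang])           # -> "Please provide export dir."
```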

View File

@@ -1,6 +1,6 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from typing import Any, Dict, List
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
@@ -24,7 +24,7 @@ class Manager:
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, dict]:
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
update_dict = {}
refresh_dict = self.gen_refresh()

View File

@@ -1,18 +1,20 @@
import gradio as gr
import logging
import os
import threading
import time
import transformers
from typing import List, Optional, Tuple
from transformers.trainer import TRAINING_ARGS_NAME
from typing import Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
from llmtuner.webui.utils import get_eval_results, update_process_bar
class Runner:
@@ -25,7 +27,9 @@ class Runner:
self.aborted = True
self.running = False
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
def initialize(
self, lang: str, model_name: str, dataset: List[str]
) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running:
return None, ALERTS["err_conflict"][lang], None, None
@@ -50,13 +54,15 @@ class Runner:
return model_name_or_path, "", logger_handler, trainer_callback
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
def finalize(
self, lang: str, finish_info: str
) -> str:
self.running = False
torch_gc()
if self.aborted:
return ALERTS["info_aborted"][lang]
else:
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
return finish_info
def run_train(
self,
@@ -78,19 +84,21 @@ class Runner:
gradient_accumulation_steps: int,
lr_scheduler_type: str,
max_grad_norm: str,
dev_ratio: float,
val_size: float,
logging_steps: int,
save_steps: int,
warmup_steps: int,
compute_type: str,
padding_side: str,
lora_rank: int,
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
output_dir: str
):
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
yield error, gr.update(visible=False)
return
if checkpoints:
@@ -100,14 +108,17 @@ class Runner:
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
args = dict(
stage="sft",
model_name_or_path=model_name_or_path,
do_train=True,
overwrite_cache=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
@@ -125,38 +136,37 @@ class Runner:
warmup_steps=warmup_steps,
fp16=(compute_type == "fp16"),
bf16=(compute_type == "bf16"),
padding_side=padding_side,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
resume_lora_training=resume_lora_training,
output_dir=output_dir
)
if dev_ratio > 1e-6:
args["dev_ratio"] = dev_ratio
if val_size > 1e-6:
args["val_size"] = val_size
args["evaluation_strategy"] = "steps"
args["eval_steps"] = save_steps
args["load_best_model_at_end"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
run_kwargs = dict(args=args, callbacks=[trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield logger_handler.log, update_process_bar(trainer_callback)
yield self.finalize(lang)
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info), gr.update(visible=False)
def run_eval(
self,
@@ -174,10 +184,10 @@ class Runner:
max_samples: str,
batch_size: int,
predict: bool
):
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
yield error, gr.update(visible=False)
return
if checkpoints:
@@ -190,6 +200,7 @@ class Runner:
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
args = dict(
stage="sft",
model_name_or_path=model_name_or_path,
do_eval=True,
overwrite_cache=True,
@@ -197,7 +208,7 @@ class Runner:
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
@@ -212,23 +223,20 @@ class Runner:
args.pop("do_eval", None)
args["do_predict"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
run_kwargs = dict(args=args, callbacks=[trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield logger_handler.log, update_process_bar(trainer_callback)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info), gr.update(visible=False)
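Both `run_train` and `run_eval` now follow the same shape: build a flat `args` dict, hand it to `run_exp` on a worker thread, and poll every two seconds, yielding a `(log, progress-update)` pair so Gradio can refresh `output_box` and `process_bar`. The loop reduced to its skeleton (the helper names here are placeholders, not the runner's API):

```python
import threading
import time
import gradio as gr

def run_and_stream(target, kwargs, read_log, read_progress):
    # Run the blocking job off the UI thread; stream status pairs back to the
    # two output components until the thread exits, then hide the bar.
    thread = threading.Thread(target=target, kwargs=kwargs)
    thread.start()
    while thread.is_alive():
        time.sleep(2)
        yield read_log(), read_progress()
    yield read_log(), gr.update(visible=False)
```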

View File

@@ -3,20 +3,30 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
def format_info(log: str, tracker: dict) -> str:
info = log
if "current_steps" in tracker:
info += "Running **{:d}/{:d}**: {} < {}\n".format(
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
)
return info
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
return gr.update(visible=False)
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps,
callback.max_steps,
callback.elapsed_time,
callback.remaining_time
)
return gr.update(label=label, value=percentage, visible=True)
def get_time() -> str:
@@ -83,3 +93,46 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def save_model(
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
template: str,
max_shard_size: int,
save_dir: str
) -> Generator[str, None, None]:
if not model_name:
yield ALERTS["err_no_model"][lang]
return
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
yield ALERTS["err_no_path"][lang]
return
if not checkpoints:
yield ALERTS["err_no_checkpoint"][lang]
return
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
if not save_dir:
yield ALERTS["err_no_save_dir"][lang]
return
args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
template=template,
output_dir=save_dir
)
yield ALERTS["info_exporting"][lang]
export_model(args, max_shard_size="{}GB".format(max_shard_size))
yield ALERTS["info_exported"][lang]

View File

@@ -1,17 +1,8 @@
from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
from llmtuner import run_exp
def main():
model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args)
run_exp()
def _mp_fn(index):
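The per-stage dispatch that used to live in every training script is absorbed into `run_exp`, which reads `stage` from its args, presumably falling back to command-line parsing when called with none, as in the script above. A sketch of the programmatic call, mirroring the dict the webui runner builds (hyperparameter values are placeholders):

```python
from llmtuner import run_exp

run_exp(args=dict(
    stage="sft",
    model_name_or_path="path_to_model",
    do_train=True,
    dataset="alpaca_gpt4_en",
    finetuning_type="lora",
    output_dir="path_to_output",
))
```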

View File

@@ -1,10 +1,10 @@
from llmtuner.webui.interface import create_ui
from llmtuner import create_ui
def main():
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
if __name__ == "__main__":

View File

@@ -1,35 +1,10 @@
# coding=utf-8
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner.tuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
from llmtuner import create_web_demo
def main():
chat_model = WebChatModel(*get_infer_args())
with gr.Blocks(title="Web Demo") as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
demo = create_web_demo()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
if __name__ == "__main__":