283 Commits

Author SHA1 Message Date
hiyouga
48d2e6d7fe tiny fix and release
Former-commit-id: 79ae5f2e06c151cd8f71a96a5ee099f034043ffd
2024-02-29 00:46:47 +08:00
hoshi-hiyouga
041c83ea03 Merge pull request #2575 from lungothrin/feature/chatter-with-role
support on-the-fly test of tools

Former-commit-id: c49af47d97ef2bae2c57dd03333752321ad6d483
2024-02-29 00:39:47 +08:00
hiyouga
0e621c2dc9 fix #2629
Former-commit-id: c18822669568327d4fbf480a80c5fe5b8fc95e7a
2024-02-29 00:37:29 +08:00
hiyouga
544e7a491b release v0.5.3
Former-commit-id: f6bc89581b3cd129448da2defc23848de6f494ed
2024-02-29 00:34:19 +08:00
hiyouga
a2c881fa08 add examples
Former-commit-id: 8cdf64adc2c8e5f194a6df26cf749d7bc9bc039f
2024-02-28 23:19:25 +08:00
hiyouga
c53c7af168 update chatglm3 template
Former-commit-id: f55e75ef3b86ea7930bb9d84b46cfc953a74441d
2024-02-28 21:11:23 +08:00
hiyouga
a2d93e5269 update readme
Former-commit-id: 654f3e174a460c621c52724b69fc4aee93370970
2024-02-28 20:50:01 +08:00
hiyouga
b392e6cfb9 support DoRA, AWQ, AQLM #2512
Former-commit-id: 6614cc1f08aa944db083e27e451bbdd733f7dd97
2024-02-28 19:53:28 +08:00
Liang Ge
13aa2d389a support on-the-fly test of tools
Former-commit-id: 95bb82fd89512ea13caf20850d1f46d8a62b4e2a
2024-02-28 01:17:49 +08:00
hoshi-hiyouga
1e7962dfc4 Merge pull request #2608 from Katehuuh/main
bump accelerate

Former-commit-id: 315662bac17c2e958d0e0b706c6e3443b8a11ec8
2024-02-27 16:49:34 +08:00
Katehuuh
1c9556c84c bump accelerate
Former-commit-id: 100deec5a8b025dbf60cf543775d2b136a75eef4
2024-02-27 08:56:45 +01:00
hiyouga
ca3ca7a5b5 add pr template
Former-commit-id: 3303855fb08316c78bf2959e3fdd6de389a1e486
2024-02-26 18:31:07 +08:00
hoshi-hiyouga
0500befdb4 Create CONTRIBUTING.md
Former-commit-id: 892ae9fd570c1c9e307ecb1fd861b8de59f2a835
2024-02-26 18:23:03 +08:00
hoshi-hiyouga
f618feab51 Create SECURITY.md
Former-commit-id: c7459a8eac77dbfbae910d468e4ac04acd9fd9de
2024-02-26 18:03:17 +08:00
hiyouga
4b06aa134f update readme
Former-commit-id: 1b1b427ea13d2a84683514d924555db974865d73
2024-02-26 17:25:47 +08:00
hoshi-hiyouga
9cde56d760 Merge pull request #2531 from Rayrtfr/main
Support Atom Model

Former-commit-id: 9868d3e85d70413e49e108297309fcc62a5c1567
2024-02-26 16:36:45 +08:00
Rayrtfr
d0ea203694 Support Atom Model
Former-commit-id: da3e76f22aca9acaf772ff821b7eb03c2a2ac869
2024-02-26 10:44:10 +08:00
hiyouga
c5eb3fba62 update webui
Former-commit-id: 298a5fc52610deb9f7d555e2fc699f10067d8af5
2024-02-25 20:23:41 +08:00
hiyouga
a8bc32553c update readme
Former-commit-id: 33c93b1e89f532073429156dac45b62542d34070
2024-02-25 16:26:08 +08:00
hoshi-hiyouga
88f3358320 Merge pull request #2525 from stephen-nju/main
update project_kwargs for ppo config

Former-commit-id: e7a6910141cc8d8dd966c1f54388d9ef764418d0
2024-02-25 15:54:00 +08:00
hiyouga
a85bdcf2f6 add papers
Former-commit-id: d1650cddf66b2d118d618eff2f6beb082000a0e4
2024-02-25 15:34:47 +08:00
hiyouga
caf56b313e add papers
Former-commit-id: edf0af7bfc4d621a59be782e57b55c0e878e5b4a
2024-02-25 15:18:58 +08:00
hiyouga
75603c45fc fix data entry
Former-commit-id: e5c116816f2d00e3bfe1a9be5886fe1e41d93212
2024-02-23 18:29:24 +08:00
hiyouga
89f86cc970 fix gemma template
Former-commit-id: 75950d115845e00318bd457e66440e2c2d98efbd
2024-02-23 13:49:53 +08:00
hiyouga
c09a0e4f08 fix template
Former-commit-id: 84673463221f2b359732de8a936a8e7ca1d003b6
2024-02-22 12:09:21 +08:00
hiyouga
7bac6c9460 fix template
Former-commit-id: 1737c7389264ef80bb8ba85c73ede0b0381e11f9
2024-02-22 12:06:48 +08:00
hiyouga
0c7d0bf172 support gemma
Former-commit-id: b9674aa2f6f1b6b09b2a37375313d8d5abfcd453
2024-02-21 23:27:36 +08:00
hiyouga
a274900188 fix #2532
Former-commit-id: 23a8e64f1c47cd473c627effbe271233c136369c
2024-02-21 21:55:14 +08:00
hiyouga
67deefe527 tiny fix
Former-commit-id: acc99ef2fb62908288f88369354135d581588b63
2024-02-21 18:30:29 +08:00
stephen
823f618cba update project_kwargs for ppo config
Former-commit-id: 14f106962fc0a87802ae9ecffff00d52f7f5f046
2024-02-21 13:47:38 +08:00
hiyouga
bc16c9a54a support lora for llama pro
Former-commit-id: f74c78ba95f0545aae89e603e466f494705ad024
2024-02-21 02:17:22 +08:00
hiyouga
a3f30038a0 fix #2516
Former-commit-id: ce2340193e751c4212650b27f16c671261015047
2024-02-20 20:44:24 +08:00
hoshi-hiyouga
e237f618c2 Merge pull request #2514 from codemayq/main
add a pre-built version of flash-attn

Former-commit-id: 2521f1c7bd39dff17de90650ddb5167f66f27940
2024-02-20 16:09:25 +08:00
hoshi-hiyouga
688adad665 Update README.md
Former-commit-id: 8a7a02fcba077778a84164a16ff2cf33ec813dc4
2024-02-20 16:07:55 +08:00
hoshi-hiyouga
0158812afb Update README_zh.md
Former-commit-id: 4c3310651b67bbea8c893d503de2b5736184daaf
2024-02-20 16:06:59 +08:00
codemayq
e52e0d9b07 1. update the version of pre-built bitsandbytes library
2. add pre-built flash-attn library


Former-commit-id: 2b76a300995a74398ee11d9274e5c0eb6ef53403
2024-02-20 11:28:25 +08:00
codemayq
eb2aa2c073 1. update the version of pre-built bitsandbytes library
2. add pre-built flash-attn library


Former-commit-id: 9b40eddf7aeb6b3bcf58374d43cbe44eb24f3849
2024-02-20 11:26:22 +08:00
hiyouga
debfd46749 release v0.5.2
Former-commit-id: 0189867816b0eab92fb2a1b5f1b1da079bd161a7
2024-02-20 11:12:43 +08:00
hiyouga
5ccf8fcd6b update webui
Former-commit-id: 9e0f7c362d40b78d57e77d52eaa96e678cebadcd
2024-02-19 16:49:58 +08:00
hiyouga
7bd1991513 add test scripts
Former-commit-id: fdaa4843961257b48cc32d83d30f2efe18b9fd5a
2024-02-19 02:09:13 +08:00
hiyouga
456e4ca569 fix safetensors
Former-commit-id: 06478ae5302d5fc6eb7afedc69335ce2f32808c6
2024-02-18 18:12:16 +08:00
hiyouga
6bf0fe4913 fix #2481
Former-commit-id: 2a4e3e4a26a2fad77ccc476be7d45434b8af4a55
2024-02-15 19:07:47 +08:00
hiyouga
596b6828cb support llama pro #2338 , add rslora
Former-commit-id: 40d659b7f30dd5a004703c176ec1f22dc864e505
2024-02-15 02:27:36 +08:00
hoshi-hiyouga
b403f8d8a8 Merge pull request #2474 from younesbelkada/add-hf-tags
FEAT: add HF tags for models that have been trained with llama-factory
Former-commit-id: f35d96817e61da9fa7789b93b0350c9f95afc40a
2024-02-14 10:26:03 +08:00
younesbelkada
590b6c2143 add v1 hf tags
Former-commit-id: a29cc9f4472c95cd6a43ea350ab728e0a8069c6e
2024-02-13 05:58:49 +00:00
hiyouga
5537ef1e7d fix #2471
Former-commit-id: a408be8be1cf99cd4468a9905c27ec454f312b9a
2024-02-12 21:07:46 +08:00
hiyouga
5f83860aa1 add option to disable version check
Former-commit-id: fd769cb2de696aee3c5e882237e16eace6a9d675
2024-02-10 22:31:23 +08:00
hiyouga
62b6a7971a update data/readme
Former-commit-id: aa566e3cea5bc75688b4399a9da07be0b35b921c
2024-02-10 21:04:29 +08:00
hiyouga
1d16e87c5f update default template
Former-commit-id: f32b55649a9f95109a6d180216eb67f959d060da
2024-02-10 16:44:47 +08:00
hiyouga
1955a8ea5a improve aligner
Former-commit-id: cc7296b92e10c24967fc753393275b71d300683f
2024-02-10 16:39:19 +08:00
hoshi-hiyouga
a41fa6e730 Merge pull request #2462 from mnmueller/main
Enable Parsing of SlimOrca

Former-commit-id: 99eed520b87152ca6b89c2a068b09200fd45f30d
2024-02-09 22:55:48 +08:00
hiyouga
b98a64448a improve fix tokenizer
Former-commit-id: 57b138abad6397596bc47be94e092e8fabedc06f
2024-02-09 14:53:14 +08:00
Mark Mueller
1ce82f391a Slim Orca data parsing
Former-commit-id: f2d8efede7e20edafed0d5446eb64f2d419949b1
2024-02-08 19:32:20 +01:00
Mark Mueller
4d473894fd Slim Orca data parsing
Former-commit-id: ca57d27c39d4e7bc3dd7c3207a23d23d2cbd446b
2024-02-08 17:56:18 +01:00
Mark Mueller
5788b7c7d0 Slim Orca data parsing
Former-commit-id: 3016427be4e63fd25f40bc5a0d1f8cedc0997334
2024-02-08 17:54:18 +01:00
Mark Mueller
04515f6b55 Slim Orca data parsing
Former-commit-id: 4dca3907964d27abc2b21eb55c75676901c98912
2024-02-08 17:52:36 +01:00
Mark Mueller
96f8ccf3d5 SlimOrca aligner
Former-commit-id: 928dda93867c2327a7957c04648592044ccf9daf
2024-02-08 08:28:32 -08:00
hoshi-hiyouga
2c3ef480a6 Merge pull request #2423 from mayflower/main
Support for german sft and dpo

Former-commit-id: 8e282e4e6bee6493b1bd38ba239ca49a6a840a92
2024-02-07 15:58:20 +08:00
hiyouga
fa6873122c Update tests.yml
Former-commit-id: c882b7cf339304ff16a36b1544a3b5f1194ef346
2024-02-07 01:18:22 +08:00
hiyouga
34bc0c22b1 lint
Former-commit-id: 6b1f89b6494e9b6b087fe90600617a3024e014e5
2024-02-07 01:10:04 +08:00
hiyouga
e5484b2729 Update pyproject.toml
Former-commit-id: 650251ea77fae2e2595ca804f49efdd230dbb5b1
2024-02-07 00:45:58 +08:00
hiyouga
f67f781fed update gc kwargs
Former-commit-id: 0cb81c156bc8c21a4bbdd3289a491f78dfcaf730
2024-02-07 00:38:24 +08:00
hiyouga
b564b97b7e fix #2438
Former-commit-id: 412d856eeada2abcea598fac0a8d35ae90cc9c01
2024-02-06 15:23:08 +08:00
hiyouga
0dd68d1e06 add models
Former-commit-id: 0fdf61b2f765c125acda4f406eb25b3e59e75db2
2024-02-06 14:57:23 +08:00
hiyouga
73f40f1ca4 support qwen1.5
Former-commit-id: 8a03a572b058c5cc4ff598670dc8595b2b97e374
2024-02-06 00:10:51 +08:00
hoshi-hiyouga
ea53bebac4 fix #2436
Update test_toolcall.py

Former-commit-id: 39c539b6470c532ac639efbd2a1c485d2f5d485f
2024-02-05 22:55:28 +08:00
hoshi-hiyouga
00418012bd Update test_toolcall.py
Former-commit-id: f50a684a9d6fc2351436d3d7020dc84bc1553a5d
2024-02-05 22:51:03 +08:00
hoshi-hiyouga
5f3d8c514b Update test_toolcall.py
Former-commit-id: 97bcae546ab80737a906e5e28953f41b657f6c99
2024-02-05 22:50:43 +08:00
tao.jun
cb39a3f1c4 Update test_toolcall.py
Add openai version notes

Former-commit-id: 9ea4ab214e64f73ec902e76b82fc42419571fd66
2024-02-05 20:49:23 +08:00
Johann-Peter Hartmann
4d78fe6ece Merge branch 'hiyouga:main' into main
Former-commit-id: efbb0153981d0650f3a581e324b83054ca8063c1
2024-02-04 13:55:00 +00:00
hiyouga
a3e3ea9846 fix #2421
Former-commit-id: 43918c12310f7560d3820e5c6d72964309afeb8b
2024-02-04 21:02:55 +08:00
Johann-Peter Hartmann
feba34e82d Merge branch 'hiyouga:main' into main
Former-commit-id: 0395d0aafb69e86645e6b0a36b8f8dadb82219e0
2024-02-04 12:51:25 +00:00
hiyouga
e134013e04 fix reserved label len
Former-commit-id: b06d6c05a1911f329252a7572240048e456affdc
2024-02-04 17:54:26 +08:00
hiyouga
5589d0296a fix #2420
Former-commit-id: 7a34087e4db62e603c9a9a26d8ff3910d7b10c40
2024-02-04 15:51:47 +08:00
hiyouga
de0ebab464 fix #2189
Former-commit-id: b3d81b229d376671e1c12229aeb487b0d84f2548
2024-02-04 00:47:37 +08:00
hiyouga
f2e7122a96 bump up transformers version
Former-commit-id: 82f4d4301ed9f31b160d6313a1d2d44a22865f4d
2024-02-04 00:01:16 +08:00
hiyouga
996cc5d900 fix #2397
Former-commit-id: 7404692808f2288d539668d364965ad104dacadb
2024-02-03 23:45:31 +08:00
hiyouga
a2ae5bd867 add hint for freeze #2412
Former-commit-id: 9600c93633629605573d908019563fa3870ad6f8
2024-02-03 23:38:56 +08:00
hiyouga
5fa52e87cb fix #2376
Former-commit-id: 8e2cfa7cca21b7fd4538d72114e36f704bcc82fe
2024-02-03 23:14:31 +08:00
hiyouga
bcd76d2c7a support minicpm #2404
Former-commit-id: 4449e91cbee8fd804cf8bf1ff6b9f5301fde94ed
2024-02-03 22:36:46 +08:00
Johann-Peter Hartmann
36fcbedc11 add simple german chatml template chatml_de
Former-commit-id: 9f1d67c09f1af2c7aa383adec66842cacde92e33
2024-02-03 09:01:15 +01:00
Johann-Peter Hartmann
1dad01cc53 Merge branch 'hiyouga:main' into main
Former-commit-id: c350237d891df7edd7e681f9da5ac1446fdeb568
2024-02-03 08:43:12 +01:00
hoshi-hiyouga
5fb21f6e54 Merge pull request #2411 from lxsyz/main
fix eos_token_id=0 bug

Former-commit-id: 019a353e74ec70a9a2d8987df1ed19483413211a
2024-02-02 17:38:16 +08:00
Fallen Angel
08dfac8352 fix eos_token_id=0 bug
when eos_token_id=0, the eos_token will never be added

Former-commit-id: 576b4881c386d897462a875b28066ce9d6e06dd5
2024-02-02 17:34:48 +08:00
Johann-Peter Hartmann
956751e419 Merge branch 'hiyouga:main' into main
Former-commit-id: 25b0a11c715f87812edba1ca14d3122a75f421de
2024-01-31 14:05:52 +01:00
hiyouga
fe2ae04c91 fix #2388
Former-commit-id: 203a36c9adfd9aa0f35fbf8089c9138534d68c53
2024-01-31 17:23:56 +08:00
hiyouga
5b8712d061 fix autoset attn impl, update data readme
Former-commit-id: 34a6e5f82baf45cc8dbb11f9f7ab4a480ab7ec5c
2024-01-31 11:58:07 +08:00
Johann-Peter Hartmann
dc7ff90c1e Add support for german datasets
Former-commit-id: bbc038aa236952597e97d1ccf1ae2d64a16339b5
2024-01-30 10:18:01 +01:00
hiyouga
1ace676170 fix #2320
Former-commit-id: e0b0c4415aaf80e75f6dd4f3777a0616b0e60f84
2024-01-24 16:19:18 +08:00
hoshi-hiyouga
8947a87b95 Merge pull request #2319 from ftgreat/main
Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed ZeRO-3

Former-commit-id: 0fadcd5f9deb9f03d341b6611c15f337f07e32d1
2024-01-24 15:32:26 +08:00
ldwang
786a2f1103 Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed ZeRO-3.
Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: 5f50c02f0e425737cd80abdf8fde9e25abf13083
2024-01-24 15:25:31 +08:00
ldwang
36ac14a566 Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed ZeRO-3.
Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: d1413dcec8a3b1d671f240b82a689c72b54d7b93
2024-01-24 14:43:16 +08:00
hiyouga
7a048fc91d add hint
Former-commit-id: c540ef41bda61993b83ef8cfe3c84b1d169e984c
2024-01-22 23:32:01 +08:00
hoshi-hiyouga
3f3756b113 Merge pull request #2283 from A-Cepheus/main
fix: ZeRO3 does not work with MoE models
Former-commit-id: f5ea760abec2aac8d29ce5c945647be05648e676
2024-01-22 23:28:45 +08:00
hoshi-hiyouga
b36c4b99cc Update patcher.py
Former-commit-id: 33556cc6b0b65cc6db02e66f4f6e75112c33d966
2024-01-22 23:27:39 +08:00
hoshi-hiyouga
9856a2276e Update tests.yml
Former-commit-id: 34151675388701afa40220729a63b0d7b2fa2a7c
2024-01-22 23:22:15 +08:00
hoshi-hiyouga
b6dc3ed3ad Create tests.yml
Former-commit-id: 9443ad76b7ef3ef1f3d184ef60652947d2c30609
2024-01-22 23:13:04 +08:00
hiyouga
75be329994 fix #2282 and update tool prompt
Former-commit-id: 1c412f803866bde32b76f7c26c7b464b6b3651f3
2024-01-22 22:27:30 +08:00
hiyouga
1fe1ca1c8b add orion models
Former-commit-id: a34db89d2a281d1a1ace29dfd5bd5d4ff7c2f657
2024-01-22 21:26:53 +08:00
A-Cepheus
882a6a1d51 🐞 fix: typo
Former-commit-id: 57a3687ecd23237559aee0e8e811b782846f2415
2024-01-22 16:04:39 +08:00
A-Cepheus
712ab4ae7a 🐞 fix: typo, move MoE fix to patcher
Former-commit-id: 4ff28e99ff9b48df7150591c6bbd3723f22b7715
2024-01-22 16:01:58 +08:00
A-Cepheus
18ad259fb3 fix: ZeRO3 does not work with MoE models
Former-commit-id: b2844c049a88ea89f8e1812e2d2e8662b4002965
2024-01-22 15:21:14 +08:00
hiyouga
fe4d93c6db add array param format
Former-commit-id: bf910f8a5b21ee552fa9ab069610a3f5f611de57
2024-01-21 22:17:48 +08:00
hiyouga
c6ba588e37 update tool test
Former-commit-id: 1d63ccc2866632596310235de15fdff660f6bee5
2024-01-21 19:41:46 +08:00
hiyouga
3fda60fca0 fix api
Former-commit-id: cca004da28aaaa0788eaea62b83d3402b38a3011
2024-01-21 19:15:27 +08:00
hiyouga
96531a0ef8 fix #2268
Former-commit-id: 300ecf9b9d7fd99fbb68f3d086e3ad973c2f894e
2024-01-21 14:11:38 +08:00
hiyouga
7abc3065fb tiny fix
Former-commit-id: 66839ae94985ddfa13eca4542127119c919b9648
2024-01-21 13:26:12 +08:00
hoshi-hiyouga
013ded4bac Merge pull request #2266 from yhyu13/fix_export_model_dtype
Remove manually set use_cache; torch_dtype is not str, save model as b…

Former-commit-id: 8c0827ba92a458e18c3b68af0330af3a65149f96
2024-01-21 12:40:39 +08:00
hoshi-hiyouga
010c3c7348 Merge branch 'main' into fix_export_model_dtype
Former-commit-id: 6c7d2729f28eb37a97820d73c05648eb7fb2ca87
2024-01-21 12:40:24 +08:00
hoshi-hiyouga
bf075c075c Update tuner.py
Former-commit-id: 691420661f7115f809e76484c1f29f74637e7cd0
2024-01-21 12:39:38 +08:00
hoshi-hiyouga
41b34e5f60 Merge pull request #2262 from fenglui/main
fix torch_dtype check of export_model

Former-commit-id: 37cacf73a534fed1b06b4f3c6724f3568ce095e3
2024-01-21 12:34:37 +08:00
hiyouga
5a889398e7 format
Former-commit-id: f28a1a0c1cdd0062ad7b6c2826f8ec107a200cff
2024-01-21 12:34:17 +08:00
hoshi-hiyouga
054cae86d8 Merge pull request #2264 from seoeaa/main
add russian lang

Former-commit-id: 15d1747de54efe69ed9f4cfd8f296fe8dd09a5c9
2024-01-21 12:25:24 +08:00
yhyu13
cd1cb8b83c Remove manually set use_cache; torch_dtype is not str, saving model as bfloat16 used to fail;
Former-commit-id: 75557fb5df16fd6eda7586cf041a16822dcfee8e
2024-01-21 11:12:15 +08:00
Aleksandr
a34779c027 add russian lang
Former-commit-id: f8ce6d75b56439027bb17ff4e59eeb9eb3b9bd34
2024-01-21 04:28:14 +03:00
fenglui
d19cb77d74 fix torch_dtype check of export_model
Former-commit-id: 8813181b6bffa76e5c7cb1f4caceada611c54b9d
2024-01-21 05:01:53 +08:00
hiyouga
ab67528e89 release v0.5.0 (real)
Former-commit-id: 2146e1d9195c179fa8f92144ec2b7034e1a9f942
2024-01-21 01:54:49 +08:00
hiyouga
27f281480a finish agent
Former-commit-id: d8d9d3afe32725fe79120fcd1a0970fdcdc45625
2024-01-21 01:47:33 +08:00
hiyouga
50459a39f4 fix api
Former-commit-id: a4149fbcd600d4f3815f9353e5e92c569719bed6
2024-01-21 00:03:09 +08:00
hiyouga
5c9815ef6f fix internlm2 template
Former-commit-id: ae05b23eb86555dbfafc174aa6ceff736e7fc9fa
2024-01-20 23:33:50 +08:00
hiyouga
aed00a97b6 fix cli_demo
Former-commit-id: e8336b3653f43618cf7cd70f8da004208de970c0
2024-01-20 23:27:10 +08:00
hiyouga
7543dc4a9d fix #2260
Former-commit-id: ba97550671811a27177306dd231bb427130b26fb
2024-01-20 23:22:09 +08:00
hiyouga
841fa0030f release v0.5.0
Former-commit-id: 602bb9b685009b9af234499be278404721542ac7
2024-01-20 20:21:39 +08:00
hiyouga
66e0e651b9 format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
2024-01-20 20:15:56 +08:00
hiyouga
1750218057 fix tests
Former-commit-id: 23f97bd437424ef43b2b84743d56acc5d1ca70d5
2024-01-20 19:58:04 +08:00
hiyouga
80637fc06d support longlora for main branch
Former-commit-id: f869501ad4c368df26534c41f62c6d63c6be17dd
2024-01-20 19:25:22 +08:00
hoshi-hiyouga
8efc055511 Merge pull request #2201 from liu-zichen/token_embed_resize
support resize embed for zero3

Former-commit-id: c0d1b5e3aef70da6b115614bd1ed539a76d6547a
2024-01-20 17:45:38 +08:00
hiyouga
be61bfda93 add upcast_lmhead option
Former-commit-id: 7ef69a1697c11ff13e7503360e40ef36cfb1c345
2024-01-19 23:54:25 +08:00
hiyouga
1a39f529c0 set use_reentrant=False
Former-commit-id: efa2e27d5ef6eaeb7baa7551c651ef10ab31400c
2024-01-19 23:29:54 +08:00
hiyouga
0868d5c550 fix #2249
Former-commit-id: 7ec64588c541422875adfdaf5692a27d05b96cb9
2024-01-19 21:44:32 +08:00
hiyouga
384f0e7678 add bf16 lora option
Former-commit-id: 58e7d7ff0cf9bf30e53b3eb12576f38d31976413
2024-01-19 16:29:03 +08:00
hiyouga
9b390c4bea fix function formatter
Former-commit-id: 363a87376ad8fe4149b387f7ccd60f31f2a5fdf7
2024-01-18 16:01:07 +08:00
hiyouga
42a13fec46 Update tuner.py
Former-commit-id: db30107385f100f88c9370abea6692ce6030a0c9
2024-01-18 15:06:02 +08:00
hiyouga
790acc4c17 fix templates
Former-commit-id: 382cc48b2a823b9a7d4ccf2c2a163f0e5b6e3169
2024-01-18 14:49:52 +08:00
hiyouga
b74cf27538 fix rm dataset
Former-commit-id: fa6f810026a59cecce813a696b2fdf15ba502fc4
2024-01-18 14:45:37 +08:00
hiyouga
ffc874ec6f fix pretrain data loader
Former-commit-id: 2a812b706ecc527013e79edc504ec18a4193123d
2024-01-18 14:42:52 +08:00
hoshi-hiyouga
546d6bd0b2 Merge pull request #2226 from hiyouga/dev
support function calling

Former-commit-id: 69391464f0d3fb0e2ef76e6b6fac51c119d66b53
2024-01-18 14:31:28 +08:00
hiyouga
8b68ca029e update readme
Former-commit-id: 11e0c732c4968b083f60a0bb6f7bb5dd5ca2ba56
2024-01-18 14:30:48 +08:00
hiyouga
502f84b30c add tool hint
Former-commit-id: 64734ffe2f45f80a1e33c2a72330b2ab1e58feb3
2024-01-18 13:19:09 +08:00
hiyouga
b7df920860 fix dataset
Former-commit-id: a7ce244a6d83d62f5bbecc588f1978e3791fd3b3
2024-01-18 12:59:30 +08:00
hiyouga
e4a424cb6a enable cutoff len
Former-commit-id: e9513d300c338dfcae98eee7d057bfd00da2da0e
2024-01-18 12:25:42 +08:00
hiyouga
d8affd3967 add tool test
Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
2024-01-18 10:26:26 +08:00
hiyouga
a423274fd9 support function calling
Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
2024-01-18 09:54:23 +08:00
hiyouga
f7329b1a0e Update llamafy_internlm2.py
Former-commit-id: 3ca5915a4fcd3d28d10a47bf9f2188b5cf8393a8
2024-01-18 01:12:31 +08:00
hiyouga
48eb07c956 Update llamafy_internlm2.py
Former-commit-id: 69b3cb768eda57b63f47cd35e5da3a59b57b7853
2024-01-18 01:00:16 +08:00
hiyouga
636d8a886c Update llamafy_internlm2.py
Former-commit-id: 1f1a7bcee5a5bb0fa17b13aa6393bfba89451dd7
2024-01-18 00:49:31 +08:00
hiyouga
97b52c7fdf fix llamafy scripts
Former-commit-id: 99ff69c36767d4397a4a61e89317ec8c0c295c1e
2024-01-18 00:37:37 +08:00
hiyouga
344412e66e fix llamafy_internlm2
Former-commit-id: a309375d020dedc313f3b6921fb53d932f156e8b
2024-01-18 00:26:14 +08:00
hiyouga
5cdea14cdf add llamafy_internlm2
Former-commit-id: 7b71767ef67cd5f246f52fb7e74b36bd26774a6c
2024-01-18 00:17:41 +08:00
hiyouga
7b1a56b96f support export push_to_hub #2183
Former-commit-id: fac09da7123a500d255de74810a8d057fb5c5f07
2024-01-16 23:59:42 +08:00
hiyouga
d1ec884e75 fix #2195
Former-commit-id: 801f7279693a0c785480ea67d663d99f4ca653da
2024-01-16 23:53:50 +08:00
liuzc
aa72a4349e support resize embed for zero3
Former-commit-id: b5464f5699b13bb118ac57ebc40b3cf9eb030396
2024-01-16 15:16:20 +08:00
hiyouga
5ab7fd0842 tiny fix
Former-commit-id: 6b1e9207e988c253a808e6bb26e3af9d071b77bc
2024-01-15 23:34:23 +08:00
hoshi-hiyouga
86d5e9802a Merge pull request #2194 from junuMoon/patch-1
fix: typo on README.md
Former-commit-id: a066a633a1a4b50cd6dc6b50701e35532fe788c1
2024-01-15 20:21:28 +08:00
Junu Moon(Fran)
18df39e3a1 fix: typo on README.md
Former-commit-id: 372066b559305a1428c88fbd6b01e332bfd5e3e1
2024-01-15 19:50:35 +09:00
hiyouga
cfe1e24471 support solar 10.7B #1907
Former-commit-id: ecf9b35c612e5514dd25b0d15835d28447a7437e
2024-01-14 00:30:30 +08:00
hiyouga
2edbe87a8c Update README_zh.md
Former-commit-id: e6d704c383e36abe8e27b3834f41d95890858425
2024-01-14 00:17:28 +08:00
hiyouga
880055bc90 support deepseek moe
Former-commit-id: 07fbb32496b9b81c4cfe67cb9a15a6b2c43852c3
2024-01-14 00:14:49 +08:00
hiyouga
ad99bd0a14 fix phi modules
Former-commit-id: 68d7e925ec51b6ee277513de8f61ac18a8378b98
2024-01-13 23:12:47 +08:00
hiyouga
c5f099138d fix #2147
Former-commit-id: 49445a03cd46af4e7036cf444cd041dfab2d8941
2024-01-12 03:30:56 +08:00
hiyouga
6e64e02f71 fix #2164
Former-commit-id: abe23bb4aca4fa571ebafc329ec9a9d457e37d41
2024-01-12 00:27:57 +08:00
hoshi-hiyouga
f95f6ec009 Merge pull request #2163 from JessyTsu1/main
Request to add "Projects using LLaMA Factory"

Former-commit-id: fa9abb430b204fabe4c1b3a569225695ae0cbc29
2024-01-11 23:33:29 +08:00
JessyTsu1
8aeecc20e1 Update README.md
Former-commit-id: 547d4df5c7a1d6dd95cfed37229701ce507b421c
2024-01-11 23:18:29 +08:00
JessyTsu1
38d0f6c63f Update README_zh.md
Former-commit-id: 8677309a38140ec1e1be3f81d0b2024df3f16c21
2024-01-11 23:17:48 +08:00
JessyTsu1
ac8534a9e7 Update README.md
Former-commit-id: dcd4858fd2c2ac4d3cce8a369dc9991108c03821
2024-01-11 23:17:00 +08:00
hiyouga
73cab9d9d4 fix #2161
Former-commit-id: 9acd5a2b678cd07f8e3b48eca76c4cbacb559e37
2024-01-11 17:04:13 +08:00
hiyouga
64246d42d2 improve web ui
Former-commit-id: 5c0148c018b12b52bc5748acfd6ad43836f2edb5
2024-01-10 12:37:45 +08:00
hiyouga
6fa6d4532e improve model export
Former-commit-id: d1b795aac1fccbcb8a9ec2057065c33b46ce1a5a
2024-01-09 22:26:24 +08:00
hiyouga
92b9956c06 modify weight name
Former-commit-id: 3f3c528fa8056dc1952ea5293bad7e55187983ff
2024-01-09 20:22:47 +08:00
hiyouga
4d6669c268 fix #1789
Former-commit-id: d86455f685fa531e651333e00b4fe54d895cf2e4
2024-01-09 18:31:27 +08:00
hiyouga
89f4ae51f9 fix #2127
Former-commit-id: 5a1aa33fa9b546ab520f0ba4cb9d996b87eb71ca
2024-01-09 14:49:13 +08:00
hiyouga
af0659f573 fix #2125
Former-commit-id: 46a22f4daeafac5b0a695212d060960ff53af613
2024-01-08 21:42:25 +08:00
hoshi-hiyouga
45a10d501e Merge pull request #2117 from dasdristanta13/main
Update requirements.txt With einops dependency

Former-commit-id: af0c05f1cffc7fc0fc74d514783333501f83f59e
2024-01-07 23:56:53 +08:00
Dristanta Das
e529ff1245 Update requirements.txt With einops dependency
Former-commit-id: 0b47b13cb34cace6fa0b6d0c58ca16fb01b3a5e9
2024-01-07 21:03:30 +05:30
hiyouga
b29371dc87 tiny fix
Former-commit-id: 06b854fe15eb4cf4ff8d6f5570068d9e74a2f1b3
2024-01-07 17:17:18 +08:00
hiyouga
0bef890000 fix api server
Former-commit-id: cedd80ba56c0090487f65f4b1227e5615943997f
2024-01-07 17:14:42 +08:00
hiyouga
75fe1404b1 improve model export
Former-commit-id: 31255147a566a23ce1a48402662d14af8ac267ab
2024-01-05 18:51:49 +08:00
hiyouga
b460c9372f fix #2098
Former-commit-id: e62d9158cffbf1044396597ddaf15b1c0bc5f954
2024-01-05 17:11:26 +08:00
hiyouga
c3e574ceaa fix qwen template
Former-commit-id: c1923e0daa02b49ac07e96ce29877729acc78d31
2024-01-05 16:14:56 +08:00
hiyouga
04ae80a52e fix #2081
Former-commit-id: ec4b539b6c0be11e15d273025c414b694bbd6c9a
2024-01-04 23:19:08 +08:00
hiyouga
a7ff095399 fix #2090
Former-commit-id: 13ec720990a88b01f7f5e2a99a87f95128dc3537
2024-01-04 23:05:08 +08:00
hiyouga
a655dcebaf fix #2067
Former-commit-id: 6cfdeea5261fd5bf6f91ba2bb3efb921a2f3e866
2024-01-04 22:53:03 +08:00
hiyouga
8c74851b70 fix dispatch
Former-commit-id: deda82638716506dc690902c51276bb1eb0ddd5e
2024-01-03 16:33:16 +08:00
hiyouga
7168392a51 fix valuehead patch
Former-commit-id: d9cb98362b58b28ae0ee207e7c07e75e5d810876
2024-01-03 16:19:23 +08:00
hiyouga
ccc5b324fe fix rm server
Former-commit-id: 81bc1638682a9fd01518f9f25250a6b584d2a9e6
2024-01-03 15:30:46 +08:00
hiyouga
e85c205a81 fix #2014
Former-commit-id: 077f6bf64e50f01f62aa4a957438bedc4e7925b3
2023-12-29 15:17:22 +08:00
hiyouga
7e225be16e add yuan model
Former-commit-id: 6a0377e2e51633bd5fb10fa8628e554565c5ee3e
2023-12-29 13:50:24 +08:00
hiyouga
ebb32e85f8 fix version
Former-commit-id: dd7500b65d0d548441eece101b60d51fa619cc0f
2023-12-29 04:53:36 +08:00
hiyouga
90d279f39f fix args
Former-commit-id: ff18f327a3dc96d9677ef32841e8f29ab2eeb7ef
2023-12-28 18:47:19 +08:00
hiyouga
af3f5b6e16 fix export format
Former-commit-id: 7c82bd396b9e6ff395850ad544d95cbf1b7557cd
2023-12-28 18:40:46 +08:00
hiyouga
53d7c5109f fix ppo trainer
Former-commit-id: ca5b5823b03822ef899405d233a82396be997f44
2023-12-28 18:09:28 +08:00
hiyouga
bf381563ff add model link
Former-commit-id: 159729929516f68aa1f43a852ed50ca0fac81523
2023-12-25 19:44:38 +08:00
hiyouga
de4b9334e1 tiny update
Former-commit-id: 4417b8ee20b381c964f452f52081667dfa33cd7b
2023-12-25 18:29:34 +08:00
hiyouga
c33fbea469 fix bug
Former-commit-id: b06faa1be3f5aa5e0fa31aa31314c213c36c3442
2023-12-24 19:20:12 +08:00
hiyouga
921f593632 update loader
Former-commit-id: 080d8eab858217ca58bffe719d5ffde7579c5bda
2023-12-24 19:10:23 +08:00
hiyouga
940403720a update patcher
Former-commit-id: d6d7b6670847ce4ea10353c5b126214542b45c2b
2023-12-23 15:24:27 +08:00
hiyouga
f869e44fe5 fix #1909
Former-commit-id: 3e93c33af9f80e28c9f30af9b7ba20757358afb4
2023-12-23 14:42:20 +08:00
hiyouga
bcc92919a0 update readme
Former-commit-id: d3dea7a926e9d356a39ca2033b03be7f559cc143
2023-12-23 02:17:41 +08:00
hiyouga
306a70c7ba fix unsloth dtype
Former-commit-id: fd22e6546ce5f38a6a075cf894aafc3d206b2fcd
2023-12-23 01:59:49 +08:00
hiyouga
d358d955e5 fix dpo trainer
Former-commit-id: c160dd7cd86e296e32775ace2e4258a473449c41
2023-12-23 01:51:55 +08:00
hiyouga
0fdd6074c3 llama board: add unsloth
Former-commit-id: 9477e6f28808ae9deadada1f6cf679a29542c271
2023-12-23 00:35:53 +08:00
hiyouga
6faf9c35a9 support unsloth
Former-commit-id: b857f00234b90b785d82ca7cdb29af3d948b1a7b
2023-12-23 00:14:33 +08:00
hoshi-hiyouga
1066898e32 Merge pull request #1953 from ShaneTian/model-load-bf16
Fix slow model initialization in bfloat16 dtype.

Former-commit-id: 69daf107c4561f807ceae066f04d432323699cef
2023-12-22 17:29:54 +08:00
ShaneTian
d05febe5de Fix slow model initialization in bfloat16 dtype.
Former-commit-id: cf2e2f6f9b7f09b1e2faf6fbc413e3f62e3846c7
2023-12-22 16:27:28 +08:00
hiyouga
67f7034a21 fix param type
Former-commit-id: 11b99f344416ade1cdac52e11ba7f36fcf689221
2023-12-21 17:33:01 +08:00
hiyouga
79f301a2c6 fix ds zero3 check
Former-commit-id: 7f50705b1d821d287bd854211319f697f57b25db
2023-12-21 01:19:22 +08:00
hiyouga
31cbc67986 match version
Former-commit-id: 16db52522584a8e084d4db2a7c253c8b88f27371
2023-12-20 22:17:35 +08:00
hoshi-hiyouga
fe66bf3663 Merge pull request #1932 from ShaneTian/main
Update transformers to 4.36.2 to resolve multi-node saving bug.

Former-commit-id: 5c55907a57e8327134e2c982c838a53c9fa42f51
2023-12-20 22:13:28 +08:00
ShaneTian
4691d4b35d Update transformers to 4.36.2 to resolve bug when saving a checkpoint in the multi-node setting.
Former-commit-id: 3173f8e51eec5e8f488e3dfc54ad371b640d6b87
2023-12-20 22:00:41 +08:00
hiyouga
acf5241845 fix stop words
Former-commit-id: 6ce6cac9fa8f0af33697e824cf93a9a80cdbd064
2023-12-20 19:06:43 +08:00
hiyouga
2bce99b82f fix yi template #1895
Former-commit-id: 05b4fa1e2b13a15ee261a151ac8cd0a2ebdf5edc
2023-12-20 18:58:16 +08:00
hiyouga
3c330869ef improve quantization
Former-commit-id: 4dde60017ad8208dfea0b2bb61df6a14a35d03e0
2023-12-20 18:27:16 +08:00
hiyouga
dba1af4841 add max_memory for gptq #1923
Former-commit-id: 9afc42c8b999fbbc206d9a467ca5795b27a10096
2023-12-20 18:15:17 +08:00
hiyouga
2b1e52dcc9 fix #1073 #1462 #1735 #1908
Former-commit-id: cd8e2535aa66931b24b96e76c2b56ce703a579b1
2023-12-20 17:15:40 +08:00
hiyouga
b5238e945a optimize data loading logic
Former-commit-id: 58f669b384582ac90e85de835f1f44f7003f9ec0
2023-12-20 16:15:41 +08:00
hiyouga
afc0f29704 fix #1909
Former-commit-id: f563e8d28dfa48a60cbe3d295b20f9cf58de296d
2023-12-20 16:11:07 +08:00
hiyouga
de0bb1d2da fix mixtral inference #1821
Former-commit-id: 612f9fd19cbd29e8b1785a1576a9668e7dcd264c
2023-12-20 15:11:15 +08:00
hiyouga
cc16ece283 fix #1900
Former-commit-id: 4c35214396f873588562606b084740b6581188d9
2023-12-19 17:21:46 +08:00
hiyouga
31ba802fc9 update readme
Former-commit-id: 36cd747e6a1a568e1a03e6c6611fec48e6ab9df7
2023-12-18 22:29:45 +08:00
hiyouga
4b27cf5460 add codegeex template
Former-commit-id: a8222722b8097158f1c92e3729f41d411eff3926
2023-12-18 19:52:35 +08:00
hiyouga
a53b2a643f add xverse-65B-2 model
Former-commit-id: 3e563a0d9666934dfdab54d61654ec00079a93f1
2023-12-18 19:24:09 +08:00
hiyouga
d925ecae1b add models
Former-commit-id: 3a4728557304996bcbe58d7d6380beead7c63c70
2023-12-18 19:09:31 +08:00
hiyouga
13fd751a78 fix tokenizer for Yi chat models #1617 #1875
Former-commit-id: 9485692c8d367a0b25d3e653db413aa01cb9ad7d
2023-12-18 17:18:11 +08:00
hiyouga
74575f8922 update readme
Former-commit-id: 01267eee0da0bffb3f0c0378e2e60d14e05585c4
2023-12-18 15:46:45 +08:00
hiyouga
5e7bb5fe73 fix llama board
Former-commit-id: f43f61b2898dda56aba0066fcb3409b152260bdb
2023-12-16 22:17:37 +08:00
hiyouga
790a31404a fix #1742
Former-commit-id: efbb32afdcf0d6aa4ca26f54c95f76dbb84f77dc
2023-12-16 20:50:45 +08:00
hiyouga
f927601702 add xverse-65b-chat model
Former-commit-id: fff6288db6b61ca27010ea47c918298f76922106
2023-12-16 20:21:29 +08:00
hiyouga
c4654d54d7 set version
Former-commit-id: 45a05e3a415eeaf875e2cf15bdba0235fbd7d527
2023-12-16 20:17:51 +08:00
hiyouga
df777c30d1 add noisy mean initialization #1815
Former-commit-id: 3253b1fca0123071913079277186c160046edf21
2023-12-16 19:47:51 +08:00
hiyouga
d81ad2d4bc support dpo-ftx
Former-commit-id: 86dfa04f9821556019fa777106787f73eb70b452
2023-12-16 19:21:41 +08:00
hiyouga
9f77e8b025 support autogptq in llama board #246
Former-commit-id: fea01226703d1534b5cf511bcb6a49e73bc86ce1
2023-12-16 16:31:30 +08:00
hoshi-hiyouga
04dc3f4614 Merge pull request #1868 from yhyu13/improve_hfargparser
Improve logging for unknown args

Former-commit-id: 6455013a99ca5c63f5b99c1100e93f794a03c497
2023-12-16 16:06:09 +08:00
yhyu13
7d1fe50977 Use llmtuner logger
Former-commit-id: ef5a560b4246e04e0ef2612e3520e05288e93707
2023-12-16 07:15:27 +00:00
yhyu13
c0e5e3c5d5 Improve logging for unknown args
Former-commit-id: 03e49d76ca91f7fcaf1c013740d5f6bfc11a2028
2023-12-16 05:16:29 +00:00
hiyouga
3a45cfb604 update tips
Former-commit-id: 4432cbda6b7535bcbb40ba77df069fca631b4be8
2023-12-15 23:52:50 +08:00
hiyouga
393e4b0f5a fix #1770
Former-commit-id: 8266187cec70bb4bd1b4837d51b09409ec11f93e
2023-12-15 23:50:15 +08:00
hiyouga
296711d502 support quantization in export model
Former-commit-id: f32500ae6edccab7d14df4c92467e15986866def
2023-12-15 23:44:50 +08:00
hiyouga
9121722999 update dc link
Former-commit-id: f6789e50e17a377b6d9b434d8e12ad99d8eecfeb
2023-12-15 22:11:31 +08:00
hoshi-hiyouga
d8d74091f6 Merge pull request #1864 from hiyouga/dev
Refactor hyper-parameters of adapters and model loader

Former-commit-id: d5ce2fb6858b9f2963f355e9f4d6f046eb6efdcd
2023-12-15 22:06:56 +08:00
hiyouga
33521fb45e fix bug
Former-commit-id: 95ac272907a04a64785f928536de1fd099150f92
2023-12-15 21:54:02 +08:00
hiyouga
e5204e60ed fix bug
Former-commit-id: 8b80baf02cfece53527c27712f0899fa3532c414
2023-12-15 21:49:26 +08:00
hiyouga
0409428d87 add configurer
Former-commit-id: c40c9889615ffb49c7ce24c69c0d3d20d841c800
2023-12-15 21:46:40 +08:00
hiyouga
f902b0d420 refactor adapter hparam
Former-commit-id: f82aece9ebd6df83a7a005cc7cbbcec07fa6e14d
2023-12-15 20:53:11 +08:00
hiyouga
27ef5b1aa7 add loftq
Former-commit-id: 0b900882ef19ac49604a24fbae8b3254f1bff7ad
2023-12-14 21:53:56 +08:00
hiyouga
c32303fc7e fix valuehead model
Former-commit-id: 9f628debb6510f2d1c91b00f121a721ab5d648e9
2023-12-14 20:15:20 +08:00
hoshi-hiyouga
45abe361ba tiny fix
Former-commit-id: 987df4c62f34026adfe2089910f4ff9ac6ebd9a6
2023-12-13 17:32:36 +08:00
hoshi-hiyouga
3ae479faae revert peft version
Former-commit-id: 6440fa1a8c28fd2db58d0905a67d071837e0edd1
2023-12-13 10:49:45 +08:00
hoshi-hiyouga
5698038f49 update peft version
Former-commit-id: 31c01e1272bd2cd9588e5ee68c1924a3dd55c67e
2023-12-13 10:23:51 +08:00
hoshi-hiyouga
020233f725 tiny fix
Former-commit-id: 1478bc052417e0939188f55a0adcbf00956960f2
2023-12-13 10:21:29 +08:00
hoshi-hiyouga
6f9d55b8eb fix #1819
Former-commit-id: f2e2b0354cbe9a7190ccab807f690cc8ab433a6e
2023-12-13 10:14:01 +08:00
hiyouga
2542b62d77 remove loftq
Former-commit-id: e175c0a1c631296117abda2403a4b87bbdd35a66
2023-12-13 01:53:46 +08:00
hiyouga
95678bb6b1 fix sharegpt loading
Former-commit-id: ad35c35f9328bff69e8b9ea7dba6a61a2dc9e28b
2023-12-13 00:56:16 +08:00
hiyouga
a78759e7ee add model urls
Former-commit-id: 3139a9fafab246f5461697efd5ed7a6599d85481
2023-12-13 00:09:17 +08:00
hiyouga
cc5c523f58 update readme
Former-commit-id: e81037d766f89f7e2b6539596397983eba52b492
2023-12-12 23:30:29 +08:00
hiyouga
e39bbdd287 support loftq
Former-commit-id: e7ac2eb7f7daae17525a278ffbe2f82c0fbd8093
2023-12-12 22:47:06 +08:00
hiyouga
d9a50bf93f fix #1795
Former-commit-id: 949ab45487155525789c08027d4f8e7da1b8bc0c
2023-12-12 19:58:34 +08:00
hiyouga
934d00ea1e support system column #1765
Former-commit-id: f425584a511c5e42bae8b3ba090eaa898b28adad
2023-12-12 19:45:59 +08:00
hiyouga
c27675f70d fix modelscope data hub
Former-commit-id: 5b63e8c22538a4788e4b6c8df50e6e6be93ceeac
2023-12-12 18:33:06 +08:00
hoshi-hiyouga
7c9f37c83d Merge pull request #1802 from tastelikefeet/feat/support_ms
Support ModelScope Datahub

Former-commit-id: f73f321e765aab9325673218779ff4ee7f281514
2023-12-12 17:58:37 +08:00
hoshi-hiyouga
b9736c13e0 Merge branch 'main' into feat/support_ms
Former-commit-id: 698756dffb7d4e602b3e0cab66ef0a4befe7215c
2023-12-12 17:55:32 +08:00
hiyouga
c47725ff34 fix webui
Former-commit-id: 15ad266206b12181788db5bb112c2299050d6139
2023-12-12 15:27:40 +08:00
xingjun.wang
3ee3fe0bbb add use_streaming
Former-commit-id: 80388abdb7ee88eb4afad92d8c706370c0574039
2023-12-12 14:23:05 +08:00
xingjun.wang
e54dad75da fix cache dir
Former-commit-id: 6231272b9c51d44196f1fbec026973231e489b67
2023-12-12 14:21:33 +08:00
xingjun.wang
39c2f03eab add print info for test
Former-commit-id: e4ae2fccf0cbec57fb5fb01fd7cc352da69b23bf
2023-12-12 14:14:40 +08:00
xingjun.wang
fb9e1c4087 update cache dir
Former-commit-id: c8a1ce847fd7a75a06659133d92a0ac42e52a839
2023-12-12 13:08:18 +08:00
xingjun.wang
ed26bb3d82 update args for MsDataset.load
Former-commit-id: c5f69357a167cbf99a93607177526e787419ea05
2023-12-12 13:02:54 +08:00
xingjun.wang
0baf32e219 update
Former-commit-id: e15fc417d897c3063a25d6eb7eb89d1916db3cc5
2023-12-12 12:03:23 +08:00
xingjun.wang
79a376d1db for test
Former-commit-id: 33d9082320098f994bfa0c6353459afcb93165b7
2023-12-12 11:52:59 +08:00
xingjun.wang
b634e91c43 for test
Former-commit-id: 95ea942bd32402018e7c5dc61d50153c602ab67a
2023-12-12 11:47:59 +08:00
hiyouga
9e2cc21d04 update readme
Former-commit-id: 42e042a4206aeb5177ddde56386e9655b0c06460
2023-12-12 11:44:30 +08:00
hiyouga
6975124a57 support mixtral
Former-commit-id: 75b5b8e36ab1933b2625f11b645f56cbc805fd85
2023-12-12 11:39:04 +08:00
hiyouga
9f69307db1 fix baichuan resize
Former-commit-id: 66956d13074a9bc74d7a737b9476f38361a7764a
2023-12-11 20:55:50 +08:00
hiyouga
c3448a045c tiny fix
Former-commit-id: 1f839fc4f278c2a258df22899241fc66a2cca682
2023-12-11 18:09:40 +08:00
hiyouga
95c561983c support resize embeddings #1786
Former-commit-id: 368a41bd3c6a04f869083058d9165954fbdad105
2023-12-11 17:50:02 +08:00
hiyouga
7a03c8dab5 use peft 0.7.0, fix #1561 #1764
Former-commit-id: 423947bd58aa50da8785b8ceca1e7e288447a9da
2023-12-11 17:13:40 +08:00
hiyouga
f3ffa8310f fix #1784
Former-commit-id: 4e1af5a5d39d9e2f374c1372e2d67120c63fea09
2023-12-09 20:53:18 +08:00
yuze.zyz
596f496f19 support ms dataset
Former-commit-id: 98638b35dc24045ac17b9b01d08d3a02372acef3
2023-12-08 18:00:57 +08:00
hiyouga
2e6ed731cf fix #1771 and temporarily fix #1764
Former-commit-id: d0e5a5d604e16c2fe0035b0ac1d54dc3625d4da3
2023-12-08 16:26:20 +08:00
hiyouga
24ce319b6f add models
Former-commit-id: 758ae7937a41a95016e70180fb343011763c1b67
2023-12-06 13:33:18 +08:00
hiyouga
7b7bfea37d fix ppo trainer save logic
Former-commit-id: 5e70c41e4e12a1109570b0ff56346fe212c028ed
2023-12-04 19:00:19 +08:00
hiyouga
3be461260a update readme
Former-commit-id: a15f8cf19cac42acfb9917a2d7c9fa36a838b360
2023-12-04 11:22:01 +08:00
hiyouga
8dab8d9831 update readme
Former-commit-id: d3c46cb126a9182be765341fe31c860d71430712
2023-12-04 11:02:29 +08:00
hiyouga
fb4c5f3c91 fix #1715
Former-commit-id: 3f9192dbbbafdc2171d2eb80282d5cae47565b7b
2023-12-03 22:35:47 +08:00
118 changed files with 6599 additions and 3620 deletions

7
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file

@@ -0,0 +1,7 @@
# What does this PR do?
Fixes # (issue)
## Before submitting
- [ ] Did you read the [contributor guideline](/CONTRIBUTING.md)?

29
.github/workflows/tests.yml vendored Normal file

@@ -0,0 +1,29 @@
name: tests

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  check_code_quality:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install ruff
      - name: Check quality
        run: |
          make style && make quality

21
CONTRIBUTING.md Normal file

@@ -0,0 +1,21 @@
# Contributing to LLaMA Factory
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
**This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to LLaMA Factory:
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Contribute to the examples or to the documentation.
### Style guide
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.

11
Makefile Normal file

@@ -0,0 +1,11 @@
.PHONY: quality style

check_dirs := src tests

quality:
	ruff $(check_dirs)
	ruff format --check $(check_dirs)

style:
	ruff $(check_dirs) --fix
	ruff format $(check_dirs)
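The Makefile above wraps the same ruff checks that the `tests` workflow runs in CI. A minimal local sketch of using it (assuming `ruff` is installed in the active environment):

```bash
python -m pip install ruff
make quality   # lint and check formatting of src/ and tests/ (read-only)
make style     # auto-fix lint issues and reformat src/ and tests/
```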

226
README.md

@@ -5,8 +5,9 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
 [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
+[![Citation](https://img.shields.io/badge/Citation-21-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
-[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
+[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20In%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
@@ -16,9 +17,7 @@
 ## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
-Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
-Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
+Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** and **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**, or launch it locally with `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`.
 Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## Table of Contents
+- [Features](#features)
 - [Benchmark](#benchmark)
 - [Changelog](#changelog)
 - [Supported Models](#supported-models)
@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [Citation](#citation)
 - [Acknowledgement](#acknowledgement)
+## Features
+- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
+- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA, 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
+- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ, agent tuning.
+- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune, rsLoRA.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 ## Benchmark
 Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
@@ -55,17 +64,29 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
-[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.
-[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
+[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
+[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
 <details><summary>Full Changelog</summary>
+[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
 [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
 [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
-[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
 [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
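The changelog entries above only name the relevant flags. As an illustrative sketch (not taken from the README itself), a single-GPU LoRA fine-tuning launch combining a few of them might look like the following; the `src/train_bash.py` entry point, model and dataset names are assumptions, while `--use_dora`, `--neftune_noise_alpha` and `--flash_attn` are the flags referenced above:

```bash
# Hypothetical single-GPU SFT run enabling DoRA, NEFTune and FlashAttention-2
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en \
    --template llama2 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --use_dora \
    --neftune_noise_alpha 5 \
    --flash_attn \
    --output_dir saves/llama2-7b-lora-sft
```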
@@ -91,19 +112,24 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | Model | Model size | Default module | Template |
 | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
-| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
-| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
+| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
-| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
+| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
+| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
-| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
-| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
+| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
+| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
+| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
+| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
+| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
 > [!NOTE]
 > **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
@@ -114,7 +140,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 ## Supported Training Approaches
-| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
@@ -123,7 +149,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 > [!NOTE]
-> Use `--quantization_bit 4/8` argument to enable QLoRA.
+> Use `--quantization_bit 4` argument to enable QLoRA.
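As a companion sketch to the note above, QLoRA only adds the quantization flag to an ordinary LoRA run; `--lora_target all` is the shortcut mentioned in the model table note, and the entry point and model are again assumptions rather than prescriptions from the README:

```bash
# Hypothetical QLoRA run: 4-bit quantized base model with LoRA on all default modules
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-13b-hf \
    --dataset alpaca_gpt4_en \
    --template llama2 \
    --finetuning_type lora \
    --lora_target all \
    --quantization_bit 4 \
    --output_dir saves/llama2-13b-qlora-sft
```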
## Provided Datasets ## Provided Datasets
@@ -145,8 +171,8 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json) - [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
@@ -162,11 +188,14 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -174,6 +203,16 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
</details>
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

</details>
## Requirement

| Mandatory    | Minimum | Recommend |
| ------------ | ------- | --------- |
| python       | 3.8     | 3.10      |
| torch        | 1.13.1  | 2.2.1     |
| transformers | 4.37.2  | 4.38.1    |
| datasets     | 2.14.3  | 2.17.1    |
| accelerate   | 0.27.2  | 0.27.2    |
| peft         | 0.9.0   | 0.9.0     |
| trl          | 0.7.11  | 0.7.11    |

| Optional     | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA         | 11.6    | 12.2      |
| deepspeed    | 0.10.0  | 0.13.4    |
| bitsandbytes | 0.39.0  | 0.41.3    |
| flash-attn   | 2.3.0   | 2.5.5     |
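For reference, the mandatory packages above can be brought up to at least their minimum versions in one step. This is only a sketch and assumes a CUDA-enabled PyTorch build matching the table is installed separately:

```bash
# Upgrade the core training dependencies to the minimum supported versions
pip install "transformers>=4.37.2" "datasets>=2.14.3" "accelerate>=0.27.2" \
    "peft>=0.9.0" "trl>=0.7.11"
```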
### Hardware Requirement

\* *estimated*

| Method | Bits |   7B  |  13B  |  30B  |   65B  |  8x7B |
| ------ | ---- | ----- | ----- | ----- | ------ | ----- |
| Full   |  16  | 160GB | 320GB | 600GB | 1200GB | 900GB |
| Freeze |  16  |  20GB |  40GB | 120GB |  240GB | 200GB |
| LoRA   |  16  |  16GB |  32GB |  80GB |  160GB | 120GB |
| QLoRA  |   8  |  10GB |  16GB |  40GB |   80GB |  80GB |
| QLoRA  |   4  |   6GB |  12GB |  24GB |   48GB |  32GB |
## Getting Started
```bash
cd LLaMA-Factory
pip install -r requirements.txt
```

If you want to enable quantized LoRA (QLoRA) on the Windows platform, you will need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2.

```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
```
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the version matching your environment from [flash-attention](https://github.com/bdashore3/flash-attention/releases).

### Use ModelScope Hub (optional)

If you have trouble downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.

```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```

Then train the model by specifying a model ID from the ModelScope Hub as `--model_name_or_path`.

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    ... # arguments (same as above)
```
LLaMA Board also supports using the models and datasets on the ModelScope Hub.

```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```

#### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage pt \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset wiki_demo \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    ... # additional arguments
```
#### Supervised Fine-Tuning

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    ... # additional arguments
```
#### Reward Modeling

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage rm \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset comparison_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_rm_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    ... # additional arguments
```
#### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --per_device_train_batch_size 2 \
    ... # additional arguments
    --fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
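For example, the PPO-tuned model can be chatted with through the CLI demo described below by stacking both adapters:

```bash
# Stack the SFT and PPO adapters on top of the base model for inference
python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint \
    --template default \
    --finetuning_type lora
```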
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
#### DPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage dpo \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset comparison_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_dpo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    ... # additional arguments
    --fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
### Distributed Training

#### Use Huggingface Accelerate
```bash
accelerate launch src/train_bash.py # arguments (same as above)
```

An example Accelerate config for multi-GPU training (excerpt):

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
# ... remaining fields
```
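If the config above is saved to a file, Accelerate can be pointed at it explicitly. This is a sketch; the file name `accelerate_config.yaml` is a placeholder:

```bash
# Launch training with an explicit Accelerate config file
accelerate launch --config_file accelerate_config.yaml src/train_bash.py \
    ... # arguments (same as above)
```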
#### Use DeepSpeed

```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    ... # arguments (same as above)
```

An excerpt of an example DeepSpeed (ZeRO-2) configuration:

```json
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
```
### Merge LoRA weights and export model

```bash
python src/export_model.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora \
    --export_dir path_to_export \
    --export_size 2 \
    --export_legacy_format False
```

> [!WARNING]
> Merging LoRA weights into a quantized model is not supported.

> [!TIP]
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.
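Building on the tip above, one possible follow-up is to quantize the already-merged model in a second pass. This is only a sketch; the input and output paths are placeholders:

```bash
# Quantize the merged model exported to path_to_export into 4-bit weights
python src/export_model.py \
    --model_name_or_path path_to_export \
    --template default \
    --export_dir path_to_quantized_export \
    --export_quantization_bit 4 \
    --export_quantization_dataset data/c4_demo.json
```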
### Inference with OpenAI-style API

```bash
python src/api_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```

> [!TIP]
> Visit `http://localhost:8000/docs` for API documentation.
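Since the server speaks an OpenAI-style dialect, a quick smoke test can be run with `curl`. The `/v1/chat/completions` route, the port and the `model` field value below are assumptions based on the OpenAI convention rather than details stated here:

```bash
# Send a minimal chat request to the locally running API server
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "default", "messages": [{"role": "user", "content": "Hello!"}]}'
```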
### Inference with command line

```bash
python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```
### Inference with web browser

```bash
python src/web_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```
### Evaluation

```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template vanilla \
    --finetuning_type lora \
    --task mmlu \
    --split test \
    --lang en \
    ... # additional arguments
```
### Predict

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 1 \
    --max_samples 100 \
    --predict_with_generate \
    --fp16
```
## Projects using LLaMA Factory

1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
> [!TIP]
> If you have a project that should be incorporated, please contact via email or create a pull request.
## License

This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

README_zh.md

[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/Citation-21-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20In%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
## LLaMA Board: Get Started with LLaMA Factory via a One-stop Web UI

Preview LLaMA Board on **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**, or launch it locally with `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`.

Below is an example of changing the self-cognition of an instruction-following language model on a single GPU within 10 minutes.

## Table of Contents

- [Features](#项目特色)
- [Benchmark](#性能指标)
- [Changelog](#更新日志)
- [Supported Models](#模型)
- [Citation](#引用)
- [Acknowledgement](#致谢)
## Features

- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, and more.
- **Integrated methods**: (continued) pre-training, supervised fine-tuning, reward modeling, PPO training and DPO training.
- **Scalable precisions**: 32-bit full-parameter tuning, 16-bit freeze tuning, 16-bit LoRA tuning and 2/4/8-bit QLoRA tuning based on AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ and agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, and more.

## Benchmark

Compared with the official [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) fine-tuning of ChatGLM, the LoRA fine-tuning in LLaMA-Factory delivers a **3.7x** speedup and achieves higher Rouge scores on the advertising text generation task. Combined with 4-bit quantization, QLoRA fine-tuning in LLaMA-Factory further reduces GPU memory usage.
## Changelog

[24/02/28] We supported **[DoRA](https://arxiv.org/abs/2402.09353)** fine-tuning. Use the `--use_dora` argument to enable DoRA training.

[24/02/15] We supported the **block expansion** method proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.

[24/02/05] The Qwen1.5 (Qwen2 beta) model series is now supported for fine-tuning in LLaMA-Factory. See this [blog post](https://qwenlm.github.io/zh/blog/qwen1.5/) for details.

<details><summary>Full changelog</summary>

[24/01/18] We supported **agent tuning** for most models. Fine-tune with `--dataset glaive_toolcall` to equip the model with tool-calling capability.

[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s LoRA training acceleration for the LLaMA, Mistral and Yi models. Use the `--use_unsloth` argument to enable it; it provides roughly a 1.7x training speedup, see [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.

[23/12/12] We supported fine-tuning the latest mixture-of-experts model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**. See the [hardware requirement](#硬件依赖) for details.

[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#使用魔搭社区可跳过) for usage.

[23/10/21] We supported the **[NEFTune](https://arxiv.org/abs/2310.05914)** training trick. Use the `--neftune_noise_alpha` argument to enable NEFTune, e.g., `--neftune_noise_alpha 5`.

[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Use the `--shift_attn` argument to enable it.

[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this repository. See [this example](#模型评估) for usage.

[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Use the `--flash_attn` argument if you are using RTX 4090, A100 or H100 GPUs.

[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Use `--rope_scaling linear` for training and `--rope_scaling dynamic` for evaluation.

[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-训练) for usage.

[23/07/31] We supported **dataset streaming**. Use the `--streaming` and `--max_steps 10000` arguments to load a dataset in streaming mode.

[23/07/29] We released two 13B instruction-tuned models on Hugging Face. See our Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
## Supported Models

| Model                                                | Model size                  | Default module  | Template  |
| ---------------------------------------------------- | --------------------------- | --------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc)     | 7B/13B                      | W_pack          | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom)     | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz)   | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B                          | query_key_value | chatglm3  |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B                  | q_proj,v_proj   | deepseek  |
| [Falcon](https://huggingface.co/tiiuae)              | 7B/40B/180B                 | query_key_value | falcon    |
| [Gemma](https://huggingface.co/google)               | 2B/7B                       | q_proj,v_proj   | gemma     |
| [InternLM2](https://huggingface.co/internlm)         | 7B/20B                      | wqkv            | intern2   |
| [LLaMA](https://github.com/facebookresearch/llama)   | 7B/13B/33B/65B              | q_proj,v_proj   | -         |
| [LLaMA-2](https://huggingface.co/meta-llama)         | 7B/13B/70B                  | q_proj,v_proj   | llama2    |
| [Mistral](https://huggingface.co/mistralai)          | 7B                          | q_proj,v_proj   | mistral   |
| [Mixtral](https://huggingface.co/mistralai)          | 8x7B                        | q_proj,v_proj   | mistral   |
| [Phi-1.5/2](https://huggingface.co/microsoft)        | 1.3B/2.7B                   | q_proj,v_proj   | -         |
| [Qwen](https://huggingface.co/Qwen)                  | 1.8B/7B/14B/72B             | c_attn          | qwen      |
| [Qwen1.5](https://huggingface.co/Qwen)               | 0.5B/1.8B/4B/7B/14B/72B     | q_proj,v_proj   | qwen      |
| [XVERSE](https://huggingface.co/xverse)              | 7B/13B/65B                  | q_proj,v_proj   | xverse    |
| [Yi](https://huggingface.co/01-ai)                   | 6B/34B                      | q_proj,v_proj   | yi        |
| [Yuan](https://huggingface.co/IEITYuan)              | 2B/51B/102B                 | q_proj,v_proj   | yuan      |

> [!NOTE]
> The **default module** is the default value of the `--lora_target` argument; use `--lora_target all` to target all available modules.

| Approach     | Full-parameter     | Freeze             | LoRA               | QLoRA              |
| ------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

> [!NOTE]
> Use the `--quantization_bit 4` argument to enable QLoRA training.

## Datasets
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
</details>

- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

</details>
## Requirements

| Mandatory    | Minimum | Recommend |
| ------------ | ------- | --------- |
| python       | 3.8     | 3.10      |
| torch        | 1.13.1  | 2.2.1     |
| transformers | 4.37.2  | 4.38.1    |
| datasets     | 2.14.3  | 2.17.1    |
| accelerate   | 0.27.2  | 0.27.2    |
| peft         | 0.9.0   | 0.9.0     |
| trl          | 0.7.11  | 0.7.11    |

| Optional     | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA         | 11.6    | 12.2      |
| deepspeed    | 0.10.0  | 0.13.4    |
| bitsandbytes | 0.39.0  | 0.41.3    |
| flash-attn   | 2.3.0   | 2.5.5     |

### Hardware Requirement

\* *estimated*

| Method | Bits |   7B  |  13B  |  30B  |   65B  |  8x7B |
| ------ | ---- | ----- | ----- | ----- | ------ | ----- |
| Full   |  16  | 160GB | 320GB | 600GB | 1200GB | 900GB |
| Freeze |  16  |  20GB |  40GB | 120GB |  240GB | 200GB |
| LoRA   |  16  |  16GB |  32GB |  80GB |  160GB | 120GB |
| QLoRA  |   8  |  10GB |  16GB |  40GB |   80GB |  80GB |
| QLoRA  |   4  |   6GB |  12GB |  24GB |   48GB |  32GB |

## Getting Started
```bash
cd LLaMA-Factory
pip install -r requirements.txt
```

To enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2.

```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
```

To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the version matching your environment from [flash-attention](https://github.com/bdashore3/flash-attention/releases).

### Use ModelScope Hub (optional)

If you have trouble downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope as follows.

```bash
export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
```

Then train the model by specifying a model ID from the ModelScope Hub as `--model_name_or_path`.

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    ... # arguments (same as above)
```

LLaMA Board also supports downloading models and datasets from the ModelScope Hub.

```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
#### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage pt \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset wiki_demo \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    ... # additional arguments
```

#### Supervised Fine-Tuning

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    ... # additional arguments
```
#### Reward Modeling

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage rm \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset comparison_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_rm_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    ... # additional arguments
```
#### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --per_device_train_batch_size 2 \
    ... # additional arguments
    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.

> [!WARNING]
> Use `--per_device_train_batch_size=1` when running fp16 PPO training with LLaMA-2 models.
#### DPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage dpo \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset comparison_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_dpo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    ... # additional arguments
    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.

### Distributed Training

#### Use Huggingface Accelerate
```bash
accelerate launch src/train_bash.py # arguments (same as above)
```

An example Accelerate config for multi-GPU training (excerpt):

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
# ... remaining fields
```

#### Use DeepSpeed

```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    ... # arguments (same as above)
```

An excerpt of an example DeepSpeed (ZeRO-2) configuration:

```json
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
```
</details>

### Merge LoRA weights and export model

```bash
python src/export_model.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora \
    --export_dir path_to_export \
    --export_size 2 \
    --export_legacy_format False
```

> [!WARNING]
> Merging LoRA weights into a quantized model is not yet supported.

> [!TIP]
> After merging the LoRA weights, use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model.

### Inference with OpenAI-style API

```bash
python src/api_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```

> [!TIP]
> Visit `http://localhost:8000/docs` for the API documentation.
### Inference with command line

```bash
python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```

### Inference with web browser

```bash
python src/web_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```
### Evaluation

```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template vanilla \
    --finetuning_type lora \
    --task ceval \
    --split validation \
    --lang zh \
    ... # additional arguments
```
### Predict

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 1 \
    --max_samples 100 \
    --predict_with_generate \
    --fp16
```
## Projects using LLaMA Factory

1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model for the Chinese legal domain, fine-tuned from Baichuan-13B, capable of legal reasoning and knowledge retrieval.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A Chinese medical large language model, fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B and fine-tuned on Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI-personality large language models, capable of giving any LLM one of 16 personality types depending on the dataset and training method.

> [!TIP]
> If you have a project that you would like added to the list above, please reach out by email or open a PR.
## License

The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

Please follow the corresponding model licenses when using the model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

SECURITY.md
# Reporting Security Issues
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/hiyouga/LLaMA-Factory/security/advisories/new) tab.
We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
Report security bugs in third-party modules to the person or team maintaining the module.

data/README.md

If you are using a custom dataset, please provide your dataset definition in `dataset_info.json` in the following format.
```json ```json
"dataset_name": { "dataset_name": {
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)", "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)", "ms_hub_url": "the name of the dataset repository on the ModelScope hub. (if specified, ignore script_url and file_name)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)", "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
"file_name": "the name of the dataset file in this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)", "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
"subset": "the name of the subset. (optional, default: None)", "subset": "the name of the subset. (optional, default: None)",
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
"ranking": "whether the dataset is a preference dataset or not. (default: false)", "ranking": "whether the dataset is a preference dataset or not. (default: false)",
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})", "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns": { "columns (optional)": {
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)", "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)", "query": "the column name in the dataset containing the queries. (default: input)",
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)", "response": "the column name in the dataset containing the responses. (default: output)",
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)", "history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)", "messages": "the column name in the dataset containing the messages. (default: conversations)",
"role": "the key in the message represents the identity. (default: from, for sharegpt)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"content": "the key in the message represents the content. (default: value, for sharegpt)" "tools": "the column name in the dataset containing the tool description. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
"content_tag": "the key in the message represents the content. (default: value)",
"user_tag": "the value of the role_tag represents the user. (default: human)",
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
} }
} }
``` ```
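For concreteness, the sketch below writes a minimal entry for a local alpaca-style file following the fields documented above; the dataset name, file name, and column choices are placeholders, not part of the original document.

```python
# Hypothetical dataset_info.json entry for a local alpaca-style file.
# The dataset name, file name, and column names are placeholders.
import json

entry = {
    "my_dataset": {
        "file_name": "my_dataset.json",
        "formatting": "alpaca",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "system": "system",
            "history": "history",
        },
    }
}

# In practice this entry would be merged into the existing dataset_info.json
# rather than written out as a standalone file.
with open("dataset_info.json", "w", encoding="utf-8") as f:
    json.dump(entry, f, indent=2, ensure_ascii=False)
```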
@@ -31,6 +42,7 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
"instruction": "user instruction (required)", "instruction": "user instruction (required)",
"input": "user input (optional)", "input": "user input (optional)",
"output": "model response (required)", "output": "model response (required)",
"system": "system prompt (optional)",
"history": [ "history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"], ["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"] ["user instruction in the second round (optional)", "model response in the second round (optional)"]
@@ -47,14 +59,15 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
"prompt": "instruction", "prompt": "instruction",
"query": "input", "query": "input",
"response": "output", "response": "output",
"system": "system",
"history": "history" "history": "history"
} }
} }
``` ```
where the `prompt` and `response` columns should contain non-empty values, representing the instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `query` column will be concatenated with the `prompt` column and used as the user prompt, i.e. the user prompt becomes `prompt\nquery`. The `response` column represents the model response.
The `history` column is a list consisting of string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**. The `system` column will be used as the system prompt. The `history` column is a list consisting of string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
For the pre-training datasets, only the `prompt` column will be used for training. For the pre-training datasets, only the `prompt` column will be used for training.
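As a rough illustration of the column mapping described above (a sketch only, not the project's actual preprocessing code), the user prompt becomes `prompt\nquery` and each history pair contributes one earlier user/assistant turn; the sample record is invented.

```python
# Illustrative sketch of how an alpaca-style record maps to chat turns;
# it mirrors the description above, not the library's exact implementation.
record = {
    "instruction": "Summarize the text.",
    "input": "LLaMA Factory supports several tuning methods.",
    "output": "It is a fine-tuning toolkit.",
    "system": "You are a helpful assistant.",
    "history": [["Hi", "Hello! How can I help?"]],
}

messages = []
for old_prompt, old_response in record.get("history", []):
    messages.append({"role": "user", "content": old_prompt})
    messages.append({"role": "assistant", "content": old_response})

# The user prompt is the instruction and the optional input joined by a newline.
user_prompt = "\n".join(filter(None, [record["instruction"], record["input"]]))
messages.append({"role": "user", "content": user_prompt})
messages.append({"role": "assistant", "content": record["output"]})

print(record["system"])
print(messages)
```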
@@ -85,7 +98,9 @@ The dataset in sharegpt format should follow the below format:
"from": "gpt", "from": "gpt",
"value": "model response" "value": "model response"
} }
] ],
"system": "system prompt (optional)",
"tools": "tool description (optional)"
} }
] ]
``` ```
@@ -96,12 +111,18 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
"dataset_name": { "dataset_name": {
"columns": { "columns": {
"messages": "conversations", "messages": "conversations",
"role": "from", "system": "system",
"content": "value" "tools": "tools"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order. where the `messages` column should be a list following the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet. Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
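For reference, here is a hedged sketch of one sharegpt-style record alongside the matching `dataset_info.json` fragment; the dataset name and contents are placeholders.

```python
# Hypothetical sharegpt-style record and the matching dataset_info.json fragment.
import json

record = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "2 + 2 equals 4."},
    ],
    "system": "You are a math tutor.",
    "tools": "",  # optional tool description, left empty here
}

dataset_info_fragment = {
    "my_sharegpt_dataset": {
        "file_name": "my_sharegpt_dataset.json",
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "system": "system", "tools": "tools"},
        "tags": {
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
            "assistant_tag": "gpt",
        },
    }
}

print(json.dumps(dataset_info_fragment, indent=2))
```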

View File

@@ -2,21 +2,32 @@
```json ```json
"数据集名称": { "数据集名称": {
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数", "hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数", "ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)", "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选留空不影响训练", "file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
"subset": "数据集子集的名称可选默认None", "subset": "数据集子集的名称可选默认None",
"folder": "Hugging Face 仓库的文件夹名称可选默认None",
"ranking": "是否为偏好数据集可选默认False", "ranking": "是否为偏好数据集可选默认False",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt", "formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns": { "columns(可选)": {
"prompt": "数据集代表提示词的表头名称默认instruction,用于 alpaca 格式", "prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input,用于 alpaca 格式", "query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output,用于 alpaca 格式", "response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None,用于 alpaca 格式", "history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations,用于 sharegpt 格式", "messages": "数据集代表消息列表的表头名称默认conversations",
"role": "消息中代表发送者身份的键名默认from用于 sharegpt 格式", "system": "数据集代表系统提示的表头名称默认None",
"content": "消息中代表文本内容的键名默认value用于 sharegpt 格式" "tools": "数据集代表工具描述的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
"content_tag": "消息中代表文本内容的键名默认value",
"user_tag": "消息中代表用户的 role_tag默认human",
"assistant_tag": "消息中代表助手的 role_tag默认gpt",
"observation_tag": "消息中代表工具返回结果的 role_tag默认observation",
"function_tag": "消息中代表工具调用的 role_tag默认function_call",
"system_tag": "消息中代表系统提示的 role_tag默认system会覆盖 system 列)"
} }
} }
``` ```
@@ -31,6 +42,7 @@
"instruction": "用户指令(必填)", "instruction": "用户指令(必填)",
"input": "用户输入(选填)", "input": "用户输入(选填)",
"output": "模型回答(必填)", "output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [ "history": [
["第一轮指令(选填)", "第一轮回答(选填)"], ["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"] ["第二轮指令(选填)", "第二轮回答(选填)"]
@@ -47,14 +59,15 @@
"prompt": "instruction", "prompt": "instruction",
"query": "input", "query": "input",
"response": "output", "response": "output",
"system": "system",
"history": "history" "history": "history"
} }
} }
``` ```
其中 `prompt``response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入 其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery``response` 列对应的内容为模型回答
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**会被用于训练**。 `system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**会被用于训练**。
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。 对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
@@ -85,7 +98,9 @@
"from": "gpt", "from": "gpt",
"value": "模型回答" "value": "模型回答"
} }
] ],
"system": "系统提示词(选填)",
"tools": "工具描述(选填)"
} }
] ]
``` ```
@@ -96,12 +111,18 @@
"数据集名称": { "数据集名称": {
"columns": { "columns": {
"messages": "conversations", "messages": "conversations",
"role": "from", "system": "system",
"content": "value" "tools": "tools"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```
其中 `messages`必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。 其中 `messages`应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
预训练数据集和偏好数据集尚不支持 sharegpt 格式。 预训练数据集和偏好数据集尚不支持 sharegpt 格式。

View File

@@ -1 +1 @@
fc9a6a3458caca2af8dafc6181773fe10c6d8657 34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b

View File

@@ -0,0 +1 @@
4748dff00d1dc42768a5b6cc772143c313017812

View File

@@ -1 +0,0 @@
38c89869c6aeca2a3af9ea1e09afe460f9b46810

View File

@@ -0,0 +1,29 @@
#!/bin/bash
deepspeed --num_gpus 4 ../../src/train_bash.py \
--deepspeed ds_z3_config.json \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en \
--dataset_dir ../../data \
--template default \
--finetuning_type full \
--output_dir ../../saves/LLaMA2-7B/full/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,30 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file config.yaml ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
--dataset comparison_gpt4_en \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/dpo \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--max_samples 1000 \
--val_size 0.1 \
--dpo_ftx 1.0 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,31 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
--dataset alpaca_gpt4_en \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model ../../saves/LLaMA2-7B/lora/reward \
--output_dir ../../saves/LLaMA2-7B/lora/ppo \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 512 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--max_samples 1000 \
--top_k 0 \
--top_p 0.9 \
--max_new_tokens 256 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,18 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft,../../saves/LLaMA2-7B/lora/dpo \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--output_dir ../../saves/LLaMA2-7B/lora/predict \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_eval_batch_size 1 \
--max_samples 20 \
--predict_with_generate

View File

@@ -0,0 +1,29 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage pt \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset c4_demo \
--dataset_dir ../../data \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/pretrain \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 10000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,31 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage rm \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
--dataset comparison_gpt4_en \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/reward \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--max_samples 5000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,30 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,30 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,30 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path TheBloke/Llama-2-7B-AWQ \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,31 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--quantization_bit 4 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,30 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path TheBloke/Llama-2-7B-GPTQ \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -1,3 +1,32 @@
[build-system] [build-system]
requires = ["setuptools>=61.0"] requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.ruff]
target-version = "py38"
line-length = 119
indent-width = 4
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["llmtuner"]
known-third-party = [
"accelerate",
"datasets",
"gradio",
"numpy",
"peft",
"torch",
"transformers",
"trl"
]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"

View File

@@ -1,14 +1,14 @@
torch>=1.13.1 torch>=1.13.1
transformers>=4.31.0,<4.35.0 transformers>=4.37.2
datasets>=2.14.0 datasets>=2.14.3
accelerate>=0.21.0 accelerate>=0.27.2
peft>=0.6.0 peft>=0.9.0
trl>=0.7.4 trl>=0.7.11
gradio>=3.38.0,<4.0.0 gradio>=3.38.0,<4.0.0
scipy scipy
einops
sentencepiece sentencepiece
protobuf protobuf
tiktoken
jieba jieba
rouge-chinese rouge-chinese
nltk nltk

View File

@@ -1,3 +1,5 @@
import os
import uvicorn import uvicorn
from llmtuner import ChatModel, create_app from llmtuner import ChatModel, create_app
@@ -6,8 +8,8 @@ from llmtuner import ChatModel, create_app
def main(): def main():
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
print("Visit http://localhost:8000/docs for API document.") print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,17 +1,19 @@
from llmtuner import ChatModel from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc from llmtuner.extras.misc import torch_gc
try: try:
import platform import platform
if platform.system() != "Windows": if platform.system() != "Windows":
import readline import readline # noqa: F401
except ImportError: except ImportError:
print("Install `readline` for a better experience.") print("Install `readline` for a better experience.")
def main(): def main():
chat_model = ChatModel() chat_model = ChatModel()
history = [] messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True: while True:
@@ -27,20 +29,20 @@ def main():
break break
if query.strip() == "clear": if query.strip() == "clear":
history = [] messages = []
torch_gc() torch_gc()
print("History has been removed.") print("History has been removed.")
continue continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = "" response = ""
for new_text in chat_model.stream_chat(query, history): for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True) print(new_text, end="", flush=True)
response += new_text response += new_text
print() print()
messages.append({"role": "assistant", "content": response})
history = history + [(query, response)]
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,10 +1,11 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams # Level: api, webui > chat, eval, train > data, model > extras, hparams
from llmtuner.api import create_app from .api import create_app
from llmtuner.chat import ChatModel from .chat import ChatModel
from llmtuner.eval import Evaluator from .eval import Evaluator
from llmtuner.train import export_model, run_exp from .train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo from .webui import create_ui, create_web_demo
__version__ = "0.3.3" __version__ = "0.5.3"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -1 +1,4 @@
from llmtuner.api.app import create_app from .app import create_app
__all__ = ["create_app"]

View File

@@ -1,28 +1,31 @@
import asyncio
import json import json
from typing import List, Tuple import os
from pydantic import BaseModel
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, Dict, Sequence
from llmtuner.api.protocol import ( from pydantic import BaseModel
Role,
Finish, from ..chat import ChatModel
ModelCard, from ..data import Role as DataRole
ModelList, from ..extras.misc import torch_gc
ChatMessage, from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
DeltaMessage, from .protocol import (
ChatCompletionMessage,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionStreamResponse,
ChatCompletionResponseChoice, ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice, ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage, ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
Finish,
Function,
FunctionCall,
ModelCard,
ModelList,
Role,
ScoreEvaluationRequest, ScoreEvaluationRequest,
ScoreEvaluationResponse ScoreEvaluationResponse,
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
) )
@@ -40,15 +43,22 @@ if is_uvicorn_available():
@asynccontextmanager @asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory async def lifespan(app: "FastAPI"): # collects GPU memory
yield yield
torch_gc() torch_gc()
def to_json(data: BaseModel) -> str: def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2 try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except: # pydantic v1 except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False) return data.json(exclude_unset=True, ensure_ascii=False)
@@ -63,6 +73,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_headers=["*"], allow_headers=["*"],
) )
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
role_mapping = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
@app.get("/v1/models", response_model=ModelList) @app.get("/v1/models", response_model=ModelList)
async def list_models(): async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo") model_card = ModelCard(id="gpt-3.5-turbo")
@@ -73,92 +92,123 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if not chat_model.can_generate: if not chat_model.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0 or request.messages[-1].role != Role.USER: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
query = request.messages[-1].content if request.messages[0].role == Role.SYSTEM:
prev_messages = request.messages[:-1] system = request.messages.pop(0).content
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
system = prev_messages.pop(0).content
else: else:
system = None system = ""
history = [] if len(request.messages) % 2 == 0:
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = ""
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
if request.stream: if request.stream:
generate = predict(query, history, system, request) if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream") return EventSourceResponse(generate, media_type="text/event-stream")
responses = chat_model.chat( responses = chat_model.chat(
query, history, system, messages,
system,
tools,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
max_new_tokens=request.max_tokens, max_new_tokens=request.max_tokens,
num_return_sequences=request.n num_return_sequences=request.n,
) )
prompt_length, response_length = 0, 0 prompt_length, response_length = 0, 0
choices = [] choices = []
for i, response in enumerate(responses): for i, response in enumerate(responses):
choices.append(ChatCompletionResponseChoice( if tools:
index=i, result = chat_model.template.format_tools.extract(response.response_text)
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text), else:
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH result = response.response_text
))
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
response_message = ChatCompletionMessage(
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
)
prompt_length = response.prompt_length prompt_length = response.prompt_length
response_length += response.response_length response_length += response.response_length
usage = ChatCompletionResponseUsage( usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length, prompt_tokens=prompt_length,
completion_tokens=response_length, completion_tokens=response_length,
total_tokens=prompt_length+response_length total_tokens=prompt_length + response_length,
) )
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): def stream_chat_completion(
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
):
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
delta=DeltaMessage(role=Role.ASSISTANT),
finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk) yield jsonify(chunk)
for new_text in chat_model.stream_chat( for new_text in chat_model.stream_chat(
query, history, system, messages,
system,
tools,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
max_new_tokens=request.max_tokens max_new_tokens=request.max_tokens,
): ):
if len(new_text) == 0: if len(new_text) == 0:
continue continue
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
delta=DeltaMessage(content=new_text),
finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk) yield jsonify(chunk)
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
delta=DeltaMessage(),
finish_reason=Finish.STOP
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk) yield jsonify(chunk)
yield "[DONE]" yield "[DONE]"
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK) @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
@@ -168,7 +218,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if len(request.messages) == 0: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_score, request)
def get_score(request: ScoreEvaluationRequest):
scores = chat_model.get_scores(request.messages, max_length=request.max_length) scores = chat_model.get_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores) return ScoreEvaluationResponse(model=request.model, scores=scores)
@@ -178,4 +233,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if __name__ == "__main__": if __name__ == "__main__":
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
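To exercise the function-calling path added above, a client can post an OpenAI-style request that includes a `tools` list. The sketch below assumes the server is running locally on the default port 8000 and that the chat completions route follows the usual OpenAI path `/v1/chat/completions`; the `get_weather` tool is purely illustrative.

```python
# Minimal client sketch against the OpenAI-compatible endpoint shown above.
# Assumes the API server is already running on localhost:8000; the
# `get_weather` tool and its schema are illustrative placeholders.
import json

import requests

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string", "description": "City name"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
choice = resp.json()["choices"][0]
if choice["finish_reason"] == "tool_calls":
    call = choice["message"]["tool_calls"][0]["function"]
    print("Tool call:", call["name"], json.loads(call["arguments"]))
else:
    print("Assistant:", choice["message"]["content"])
```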

View File

@@ -1,30 +1,48 @@
import time import time
from enum import Enum from enum import Enum, unique
from pydantic import BaseModel, Field
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
class Role(str, Enum): class Role(str, Enum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
SYSTEM = "system" SYSTEM = "system"
FUNCTION = "function"
TOOL = "tool"
@unique
class Finish(str, Enum): class Finish(str, Enum):
STOP = "stop" STOP = "stop"
LENGTH = "length" LENGTH = "length"
TOOL = "tool_calls"
class ModelCard(BaseModel): class ModelCard(BaseModel):
id: str id: str
object: Optional[str] = "model" object: Literal["model"] = "model"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Optional[str] = "owner" owned_by: Literal["owner"] = "owner"
class ModelList(BaseModel): class ModelList(BaseModel):
object: Optional[str] = "list" object: Literal["list"] = "list"
data: Optional[List[ModelCard]] = [] data: List[ModelCard] = []
class Function(BaseModel):
name: str
arguments: str
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function"
function: Function
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
@@ -32,31 +50,33 @@ class ChatMessage(BaseModel):
content: str content: str
class DeltaMessage(BaseModel): class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None role: Optional[Role] = None
content: Optional[str] = None content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: List[ChatMessage]
do_sample: Optional[bool] = True tools: Optional[list] = []
do_sample: bool = True
temperature: Optional[float] = None temperature: Optional[float] = None
top_p: Optional[float] = None top_p: Optional[float] = None
n: Optional[int] = 1 n: int = 1
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
stream: Optional[bool] = False stream: bool = False
class ChatCompletionResponseChoice(BaseModel): class ChatCompletionResponseChoice(BaseModel):
index: int index: int
message: ChatMessage message: ChatCompletionMessage
finish_reason: Finish finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel):
index: int index: int
delta: DeltaMessage delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None finish_reason: Optional[Finish] = None
@@ -67,18 +87,18 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default" id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Optional[str] = "chat.completion" object: Literal["chat.completion"] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseChoice] choices: List[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage usage: ChatCompletionResponseUsage
class ChatCompletionStreamResponse(BaseModel): class ChatCompletionStreamResponse(BaseModel):
id: Optional[str] = "chatcmpl-default" id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Optional[str] = "chat.completion.chunk" object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseStreamChoice] choices: List[ChatCompletionResponseStreamChoice]
@@ -90,7 +110,7 @@ class ScoreEvaluationRequest(BaseModel):
class ScoreEvaluationResponse(BaseModel): class ScoreEvaluationResponse(BaseModel):
id: Optional[str] = "scoreeval-default" id: Literal["scoreeval-default"] = "scoreeval-default"
object: Optional[str] = "score.evaluation" object: Literal["score.evaluation"] = "score.evaluation"
model: str model: str
scores: List[float] scores: List[float]

View File

@@ -1 +1,4 @@
from llmtuner.chat.chat_model import ChatModel from .chat_model import ChatModel
__all__ = ["ChatModel"]

View File

@@ -1,18 +1,18 @@
import torch
import tiktoken
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread from threading import Thread
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.data.template import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer from ..hparams import get_infer_args
from ..model import dispatch_model, load_model_and_tokenizer
@dataclass @dataclass
class Response: class Response:
response_text: str response_text: str
response_length: int response_length: int
prompt_length: int prompt_length: int
@@ -20,28 +20,26 @@ class Response:
class ChatModel: class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.can_generate = (finetuning_args.stage == "sft") self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer( self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) )
self.tokenizer.padding_side = "left" if self.can_generate else "right" self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.model = dispatch_model(self.model) self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.system_prompt = data_args.system_prompt
def _process_args( def _process_args(
self, self,
query: str, messages: Sequence[Dict[str, str]],
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
**input_kwargs tools: Optional[str] = None,
**input_kwargs,
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
system = system or self.system_prompt paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt, _ = self.template.encode_oneturn( prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
) )
prompt_length = len(prompt) prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device) input_ids = torch.tensor([prompt], device=self.model.device)
@@ -56,16 +54,18 @@ class ChatModel:
max_new_tokens = input_kwargs.pop("max_new_tokens", None) max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.to_dict() generating_args = self.generating_args.to_dict()
generating_args.update(dict( generating_args.update(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"], dict(
temperature=temperature or generating_args["temperature"], do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
top_p=top_p or generating_args["top_p"], temperature=temperature or generating_args["temperature"],
top_k=top_k or generating_args["top_k"], top_p=top_p or generating_args["top_p"],
num_return_sequences=num_return_sequences or 1, top_k=top_k or generating_args["top_k"],
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], num_return_sequences=num_return_sequences or 1,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
pad_token_id=self.tokenizer.pad_token_id eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
)) pad_token_id=self.tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1: if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True generating_args["do_sample"] = True
@@ -81,7 +81,7 @@ class ChatModel:
gen_kwargs = dict( gen_kwargs = dict(
inputs=input_ids, inputs=input_ids,
generation_config=GenerationConfig(**generating_args), generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor() logits_processor=get_logits_processor(),
) )
return gen_kwargs, prompt_length return gen_kwargs, prompt_length
@@ -89,17 +89,15 @@ class ChatModel:
@torch.inference_mode() @torch.inference_mode()
def chat( def chat(
self, self,
query: str, messages: Sequence[Dict[str, str]],
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
**input_kwargs tools: Optional[str] = None,
**input_kwargs,
) -> List[Response]: ) -> List[Response]:
r""" if not self.can_generate:
Args: query, history, system, **input_kwargs raise ValueError("The current model does not support `chat`.")
Returns: [(response_text, prompt_length, response_length)] * n (default n=1) gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs) generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode( response = self.tokenizer.batch_decode(
@@ -109,24 +107,29 @@ class ChatModel:
for i in range(len(response)): for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(Response( results.append(
response_text=response[i], Response(
response_length=response_length, response_text=response[i],
prompt_length=prompt_length, response_length=response_length,
finish_reason="stop" if len(eos_index) else "length" prompt_length=prompt_length,
)) finish_reason="stop" if len(eos_index) else "length",
)
)
return results return results
@torch.inference_mode() @torch.inference_mode()
def stream_chat( def stream_chat(
self, self,
query: str, messages: Sequence[Dict[str, str]],
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
**input_kwargs tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs) if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@@ -136,27 +139,19 @@ class ChatModel:
yield from streamer yield from streamer
@torch.inference_mode() @torch.inference_mode()
def get_scores( def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
self, if self.can_generate:
batch_input: List[str], raise ValueError("Cannot get scores using an auto-regressive model.")
**input_kwargs
) -> List[float]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
max_length = input_kwargs.pop("max_length", None) max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda") device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer( inputs = self.tokenizer(
batch_input, batch_input,
padding=True, padding=True,
truncation=True, truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024), max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
pad_to_multiple_of=8,
return_tensors="pt", return_tensors="pt",
**kwargs add_special_tokens=True,
).to(device) ).to(device)
input_ids: torch.Tensor = inputs["input_ids"] input_ids: torch.Tensor = inputs["input_ids"]
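Given the new messages-based signature above, a minimal usage sketch follows; the model path, adapter path, and template are placeholders, not recommendations.

```python
# Sketch of the updated ChatModel interface, which now takes OpenAI-style
# message dicts instead of a (query, history) pair. Paths are placeholders.
from llmtuner import ChatModel

chat_model = ChatModel(
    dict(
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        adapter_name_or_path="../../saves/LLaMA2-7B/lora/sft",
        template="default",
        finetuning_type="lora",
    )
)

messages = [{"role": "user", "content": "Explain LoRA in one sentence."}]
for response in chat_model.chat(messages, system="You are a concise assistant."):
    print(response.response_text)

# The streaming variant yields text chunks instead of Response objects.
for new_text in chat_model.stream_chat(messages):
    print(new_text, end="", flush=True)
```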

View File

@@ -1,4 +1,6 @@
from llmtuner.data.loader import get_dataset from .loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset from .template import get_template_and_fix_tokenizer, templates
from llmtuner.data.template import get_template_and_fix_tokenizer from .utils import Role, split_dataset
from llmtuner.data.utils import split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]

View File

@@ -0,0 +1,133 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from .utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else:
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)
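To make the docstring above concrete, this is roughly the shape of one aligned example after `align_dataset`; the contents are illustrative placeholders.

```python
# Illustrative shape of a single aligned example produced by align_dataset:
# "prompt" holds every turn except the final reply (history first, current
# instruction last), while "response" holds the reply to be learned.
aligned_example = {
    "prompt": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "Summarize the text.\nLLaMA Factory supports several tuning methods."},
    ],
    "response": [{"role": "assistant", "content": "It is a fine-tuning toolkit."}],
    "system": "You are a helpful assistant.",
    "tools": "",
}
```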

View File

@@ -0,0 +1,155 @@
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool{format_prompt}.\n"
"```\n"
)
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
items = (
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
)
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
items=items,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
action_match = re.search(regex, content)
if not action_match:
return content
tool_name = action_match.group(1).strip()
tool_input = action_match.group(2).strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return tool_name, json.dumps(arguments, ensure_ascii=False)
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
...
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
raise NotImplementedError
@dataclass
class EmptyFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value))
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class FunctionFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
function = json.loads(content)
name = function["name"]
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except Exception:
name, arguments = "", ""
elements = []
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class ToolFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
if not len(tools):
return [""]
if self.tool_format == "default":
return [default_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
return [""]
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
if self.tool_format == "default":
return default_tool_extractor(content)
else:
raise NotImplementedError
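A short round-trip sketch of the tool formatter/extractor pair defined above, assuming the module path `llmtuner.data.formatter` and a purely illustrative `get_weather` tool.

```python
# Build the tool-use system prompt from a tool schema, then parse a model
# reply back into (name, json_arguments). The import path is assumed from the
# file location; the tool definition is an illustrative placeholder.
from llmtuner.data.formatter import default_tool_extractor, default_tool_formatter

tools = [
    {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name"}},
            "required": ["city"],
        },
    }
]

# System prompt injected before the conversation when tools are provided.
system_prompt = default_tool_formatter(tools)
print(system_prompt)

# Extract the tool call from a model response in the default format.
model_output = 'Action: get_weather\nAction Input: {"city": "Berlin"}'
print(default_tool_extractor(model_output))  # -> ('get_weather', '{"city": "Berlin"}')
```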

View File

@@ -1,135 +1,122 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Literal, Union

from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk

from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ..hparams import DataArguments, ModelArguments
    from .parser import DatasetAttr


logger = get_logger(__name__)


def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
):
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder
    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder
    elif dataset_attr.load_from == "file":
        data_files = []
        local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                data_files.append(os.path.join(local_path, file_name))
                if data_path is None:
                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
                    raise ValueError("File types should be identical.")
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
        else:
            raise ValueError("File not found.")

        if data_path is None:
            raise ValueError("File extension must be txt, csv, json or jsonl.")

        checksum(data_files, dataset_attr.file_sha1)
    else:
        raise NotImplementedError

    if dataset_attr.load_from == "ms_hub":
        try:
            from modelscope import MsDataset
            from modelscope.utils.config_ds import MS_DATASETS_CACHE

            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
            dataset = MsDataset.load(
                dataset_name=data_path,
                subset_name=data_name,
                data_dir=data_dir,
                data_files=data_files,
                split=data_args.split,
                cache_dir=cache_dir,
                token=model_args.ms_hub_token,
                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            ).to_hf_dataset()
        except ImportError:
            raise ImportError("Please install modelscope via `pip install modelscope -U`")
    else:
        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
            kwargs = {"trust_remote_code": True}
        else:
            kwargs = {}

        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=data_args.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            **kwargs,
        )

    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter

    if data_args.max_samples is not None:  # truncate dataset
        num_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(num_samples))

    return align_dataset(dataset, dataset_attr, data_args)


def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    if len(all_datasets) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
@@ -141,8 +128,64 @@ def get_dataset(
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=training_args.seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )
    else:
        raise ValueError("Unknown mixing strategy.")


def get_dataset(
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"],
    # split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    # Load from cache
    if data_args.cache_path is not None:
        if os.path.exists(data_args.cache_path):
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.cache_path)
            if data_args.streaming:
                dataset = dataset.to_iterable_dataset()
            return dataset

    with training_args.main_process_first(desc="load dataset"):
        all_datasets = []
        for dataset_attr in get_dataset_list(data_args):
            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
        dataset = merge_dataset(all_datasets, data_args, training_args)

    with training_args.main_process_first(desc="pre-process dataset"):
        preprocess_func, print_function = get_preprocess_and_print_func(
            tokenizer, template, data_args, training_args, stage
        )
        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        if not data_args.streaming:
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=(not data_args.overwrite_cache),
                desc="Running tokenizer on dataset",
            )

        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
            if training_args.should_save:
                dataset.save_to_disk(data_args.cache_path)
                logger.info("Dataset cache saved at {}.".format(data_args.cache_path))

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

        return dataset
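As a quick aside (not from the diff), the two mix_strategy branches in merge_dataset map directly onto primitives from the `datasets` library; the toy data below is made up for illustration.

from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2", "b3"]})

print(concatenate_datasets([ds_a, ds_b])["text"])  # "concat": a1 a2 a3 b1 b2 b3
print(
    interleave_datasets(
        [ds_a, ds_b], probabilities=[0.5, 0.5], seed=42, stopping_strategy="all_exhausted"
    )["text"]
)  # "interleave_over": sample from both datasets until each is exhausted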

src/llmtuner/data/parser.py Normal file
View File

@@ -0,0 +1,119 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope


if TYPE_CHECKING:
    from ..hparams import DataArguments


@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.
    """

    """ basic configs """
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: Optional[str] = None
    """ extra configs """
    file_sha1: Optional[str] = None
    subset: Optional[str] = None
    folder: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
    """ columns """
    system: Optional[str] = None
    """ columns for the alpaca format """
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    """ columns for the sharegpt format """
    messages: Optional[str] = "conversations"
    tools: Optional[str] = None
    """ tags for the sharegpt format """
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))


def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
    dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
    try:
        with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
            dataset_info = json.load(f)
    except Exception as err:
        if data_args.dataset is not None:
            raise ValueError(
                "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
            )
        dataset_info = None

    if data_args.interleave_probs is not None:
        data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]

    dataset_list: List[DatasetAttr] = []
    for name in dataset_names:
        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        has_hf_url = "hf_hub_url" in dataset_info[name]
        has_ms_url = "ms_hub_url" in dataset_info[name]

        if has_hf_url or has_ms_url:
            if (use_modelscope() and has_ms_url) or (not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
        elif "script_url" in dataset_info[name]:
            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

        dataset_attr.set_attr("file_sha1", dataset_info[name])
        dataset_attr.set_attr("subset", dataset_info[name])
        dataset_attr.set_attr("folder", dataset_info[name])
        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")

        if "columns" in dataset_info[name]:
            column_names = ["system"]
            if dataset_attr.formatting == "alpaca":
                column_names.extend(["prompt", "query", "response", "history"])
            else:
                column_names.extend(["messages", "tools"])

            for column_name in column_names:
                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
            tag_names = (
                "role_tag",
                "content_tag",
                "user_tag",
                "assistant_tag",
                "observation_tag",
                "function_tag",
                "system_tag",
            )
            for tag in tag_names:
                dataset_attr.set_attr(tag, dataset_info[name]["tags"])

        dataset_list.append(dataset_attr)

    return dataset_list
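To make the fields concrete, here is a hypothetical dataset_info.json entry (written as a Python dict) that uses only keys read by get_dataset_list above; the dataset and file names are invented.

example_dataset_info = {
    "my_sharegpt_data": {  # resolved as load_from="file" because only file_name is given
        "file_name": "my_sharegpt_data.json",
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "system": "system"},
        "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"},
    }
}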

View File

@@ -1,275 +1,269 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple

from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from .utils import Role


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ..hparams import DataArguments
    from .template import Template


logger = get_logger(__name__)


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    # build grouped texts with format `X1 X2 X3 ...`
    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
    concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
    total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
    block_size = data_args.cutoff_len
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
    total_length = (total_length // block_size) * block_size
    # split by chunks of cutoff_len
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    return result


def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            continue

        messages = examples["prompt"][i] + examples["response"][i]
        input_ids, labels = [], []
        for turn_idx, (source_ids, target_ids) in enumerate(
            template.encode_multiturn(
                tokenizer,
                messages,
                examples["system"][i],
                examples["tools"][i],
                data_args.cutoff_len,
                data_args.reserved_label_len,
            )
        ):
            if data_args.train_on_prompt:
                source_mask = source_ids
            elif turn_idx != 0 and template.efficient_eos:
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
            else:
                source_mask = [IGNORE_INDEX] * len(source_ids)

            input_ids += source_ids + target_ids
            labels += source_mask + target_ids

        if template.efficient_eos:
            input_ids += [tokenizer.eos_token_id]
            labels += [tokenizer.eos_token_id]

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)

    return model_inputs


def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    input_ids, labels = [], []
    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            continue

        messages = examples["prompt"][i] + examples["response"][i]
        for source_ids, target_ids in template.encode_multiturn(
            tokenizer, messages, examples["system"][i], examples["tools"][i]
        ):
            if data_args.train_on_prompt:
                source_mask = source_ids
            elif len(input_ids) != 0 and template.efficient_eos:
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
            else:
                source_mask = [IGNORE_INDEX] * len(source_ids)

            input_ids += source_ids + target_ids
            labels += source_mask + target_ids

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    total_length = len(input_ids)
    block_size = data_args.cutoff_len
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
    total_length = (total_length // block_size) * block_size
    # split by chunks of cutoff_len
    for i in range(0, total_length, block_size):
        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
            model_inputs["input_ids"].append(input_ids[i : i + block_size])
            model_inputs["attention_mask"].append([1] * block_size)
            model_inputs["labels"].append(labels[i : i + block_size])

    return model_inputs


def preprocess_unsupervised_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1:
            continue

        if len(examples["response"][i]) == 1:
            messages = examples["prompt"][i] + examples["response"][i]
        else:
            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]

        input_ids, labels = template.encode_oneturn(
            tokenizer,
            messages,
            examples["system"][i],
            examples["tools"][i],
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )

        if template.efficient_eos:
            labels += [tokenizer.eos_token_id]

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)

    return model_inputs


def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
            continue

        chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
        rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
        prompt_ids, chosen_ids = template.encode_oneturn(
            tokenizer,
            chosen_messages,
            examples["system"][i],
            examples["tools"][i],
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )
        _, rejected_ids = template.encode_oneturn(
            tokenizer,
            rejected_messages,
            examples["system"][i],
            examples["tools"][i],
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )

        if template.efficient_eos:
            chosen_ids += [tokenizer.eos_token_id]
            rejected_ids += [tokenizer.eos_token_id]

        model_inputs["prompt_ids"].append(prompt_ids)
        model_inputs["chosen_ids"].append(chosen_ids)
        model_inputs["rejected_ids"].append(rejected_ids)

    return model_inputs


def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print(
        "labels:\n{}".format(
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
        )
    )


def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("prompt_ids:\n{}".format(example["prompt_ids"]))
    print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
    print("chosen_ids:\n{}".format(example["chosen_ids"]))
    print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
    print("rejected_ids:\n{}".format(example["rejected_ids"]))
    print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))


def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))


def get_preprocess_and_print_func(
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not training_args.predict_with_generate:
        if data_args.sft_packing:
            preprocess_func = partial(
                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
            )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
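For intuition, here is a small self-contained sketch of the block-packing idea used by preprocess_packed_supervised_dataset above; the token ids and cutoff length are made up for the example.

IGNORE_INDEX = -100
input_ids = list(range(1, 11))                    # ten concatenated token ids
labels = [IGNORE_INDEX] * 4 + list(range(5, 11))  # the prompt part is masked out
block_size = 4
total_length = (len(input_ids) // block_size) * block_size  # drop the remainder -> 8

blocks = [
    (input_ids[i : i + block_size], labels[i : i + block_size])
    for i in range(0, total_length, block_size)
    if not all(label == IGNORE_INDEX for label in labels[i : i + block_size])
]
print(blocks)  # [([5, 6, 7, 8], [5, 6, 7, 8])] -- the fully-masked first block is skipped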

File diff suppressed because it is too large

View File

@@ -1,25 +1,27 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from ..extras.logging import get_logger


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import TrainingArguments

    from llmtuner.hparams import DataArguments


logger = get_logger(__name__)


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
@@ -37,13 +39,18 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
        logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))


def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - max_target_len
    return max_source_len, max_target_len


def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
    if training_args.do_train:
        if data_args.val_size > 1e-6:  # Split the dataset
            if data_args.streaming:
                val_set = dataset.take(int(data_args.val_size))
                train_set = dataset.skip(int(data_args.val_size))
@@ -57,5 +64,5 @@ def split_dataset(
            if data_args.streaming:
                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
            return {"train_dataset": dataset}
    else:  # do_eval or do_predict
        return {"eval_dataset": dataset}

View File

@@ -1 +1,4 @@
from .evaluator import Evaluator


__all__ = ["Evaluator"]

View File

@@ -1,41 +1,34 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

import inspect
import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm, trange
from transformers.utils import cached_file

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer
from .template import get_eval_template


class Evaluator:
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
        self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
        self.model = dispatch_model(self.model)
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
        self.eval_template = get_eval_template(self.eval_args.lang)
        self.choice_inputs = [
            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
        ]

    @torch.inference_mode()
    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
@@ -46,16 +39,11 @@ class Evaluator:
        return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

    def eval(self) -> None:
        mapping = cached_file(
            path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
            filename="mapping.json",
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.hf_hub_token,
        )

        with open(mapping, "r", encoding="utf-8") as f:
@@ -65,37 +53,45 @@ class Evaluator:
        pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
        results = {}
        for subject in pbar:
            if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
                kwargs = {"trust_remote_code": True}
            else:
                kwargs = {}

            dataset = load_dataset(
                path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                name=subject,
                cache_dir=self.model_args.cache_dir,
                download_mode=self.eval_args.download_mode,
                token=self.model_args.hf_hub_token,
                **kwargs,
            )
            pbar.set_postfix_str(categorys[subject]["name"])
            inputs, outputs, labels = [], [], []
            for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
                support_set = (
                    dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
                )
                messages = self.eval_template.format_example(
                    target_data=dataset[self.data_args.split][i],
                    support_set=support_set,
                    subject_name=categorys[subject]["name"],
                )

                input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
                labels.append(messages[-1]["content"])

            for i in trange(
                0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
            ):
                batch_input = self.tokenizer.pad(
                    inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
                ).to(self.model.device)
                preds = self.batch_inference(batch_input)
                outputs += preds

            corrects = np.array(outputs) == np.array(labels)
            category_name = categorys[subject]["category"]
            category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
            category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
@@ -105,10 +101,13 @@ class Evaluator:
        self._save_results(category_corrects, results)

    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
        score_info = "\n".join(
            [
                "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
                for category_name, category_correct in category_corrects.items()
                if len(category_correct)
            ]
        )
        print(score_info)
        if self.eval_args.save_dir is not None:
            os.makedirs(self.eval_args.save_dir, exist_ok=False)
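A hypothetical programmatic invocation of the refactored Evaluator; the keys task, task_dir, lang, n_shot and batch_size mirror the eval hparams referenced above, while the model path, template and task values are placeholders.

from llmtuner.eval.evaluator import Evaluator

Evaluator(
    args={
        "model_name_or_path": "path/to/model",
        "template": "default",
        "task": "mmlu",
        "task_dir": "evaluation",
        "lang": "en",
        "n_shot": 5,
        "batch_size": 4,
    }
).eval()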

View File

@@ -1,7 +1,9 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple

from ..data import Role
from ..extras.constants import CHOICES


if TYPE_CHECKING:
    from datasets import Dataset
@@ -9,60 +11,39 @@ if TYPE_CHECKING:
@dataclass
class EvalTemplate:
    system: str
    choice: str
    answer: str
    prefix: str

    def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
    ) -> List[Dict[str, str]]:
        messages = []
        for k in range(len(support_set)):
            prompt, response = self.parse_example(support_set[k])
            messages.append({"role": Role.USER, "content": prompt})
            messages.append({"role": Role.ASSISTANT, "content": response})

        prompt, response = self.parse_example(target_data)
        messages.append({"role": Role.USER, "content": prompt})
        messages.append({"role": Role.ASSISTANT, "content": response})
        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
        return messages


eval_templates: Dict[str, "EvalTemplate"] = {}


def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)


def get_eval_template(name: str) -> "EvalTemplate":
    eval_template = eval_templates.get(name, None)
    assert eval_template is not None, "Template {} does not exist.".format(name)
    return eval_template
@@ -73,7 +54,7 @@ register_eval_template(
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer: ",
    prefix=" ",
)
@@ -82,5 +63,5 @@ register_eval_template(
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
    answer="\n答案:",
    prefix="\n",
)
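Illustration (contents invented) of the message list that format_example produces for one support example plus the target question: the system prompt is folded into the first user turn, and the final assistant turn carries the gold answer.

messages = [
    {"role": "user", "content": (
        "The following are multiple choice questions (with answers) about math.\n\n"
        "1 + 1 = ?\nA. 1\nB. 2\nC. 3\nD. 4\nAnswer: "
    )},
    {"role": "assistant", "content": "B"},
    {"role": "user", "content": "2 + 2 = ?\nA. 2\nB. 3\nC. 4\nD. 5\nAnswer: "},
    {"role": "assistant", "content": "C"},
]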

View File

@@ -1,56 +1,38 @@
import json
import os
import time
from datetime import timedelta
from typing import TYPE_CHECKING

from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)


class FixValueHeadModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )


class LogCallback(TrainerCallback):
    def __init__(self, runner=None):
        self.runner = runner
        self.in_training = False
@@ -116,7 +98,9 @@ class LogCallback(TrainerCallback):
        self.cur_steps = 0
        self.max_steps = 0

    def on_predict(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
    ):
        r"""
        Event called after a successful prediction.
        """
@@ -142,18 +126,22 @@ class LogCallback(TrainerCallback):
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if self.runner is not None:
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
                )
            )

        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.
        """

File diff suppressed because it is too large

View File

@@ -1,5 +1,5 @@
import logging
import sys


class LoggerHandler(logging.Handler):
@@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
    Gets a standard logger with a stream handler to stdout.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

View File

@@ -1,35 +1,45 @@
import gc import gc
import os import os
import sys from typing import TYPE_CHECKING, Dict, Tuple
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xpu_available,
)
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try: try:
from transformers.utils import ( _is_bf16_available = is_torch_bf16_gpu_available()
is_torch_bf16_cpu_available, except Exception:
is_torch_bf16_gpu_available, _is_bf16_available = False
is_torch_cuda_available,
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except:
_is_bf16_available = False
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import HfArgumentParser from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
class AverageMeter: class AverageMeter:
r""" r"""
Computes and stores the average and current value. Computes and stores the average and current value.
""" """
def __init__(self): def __init__(self):
self.reset() self.reset()
@@ -68,6 +78,76 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    return trainable_params, all_param


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        from safetensors import safe_open
        from safetensors.torch import save_file

        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param

    os.remove(path_to_checkpoint)
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info("Value head model saved at: {}".format(output_dir))


def get_current_device() -> torch.device:
    r"""
    Gets the current available device.
    """
    if is_torch_xpu_available():
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_mps_available():
        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_cuda_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        device = "cpu"

    return torch.device(device)


def get_device_count() -> int:
    return torch.cuda.device_count()


def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets logits processor that removes NaN and Inf logits.
@@ -89,17 +169,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    return torch.float32
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def torch_gc() -> None:
    r"""
    Collects GPU memory.
@@ -115,12 +184,11 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
        return

    try:
        from modelscope import snapshot_download

        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        model_args.model_name_or_path = snapshot_download(
            model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
        )
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")
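To make the checkpoint-splitting logic in fix_valuehead_checkpoint easier to follow, here is a small self-contained sketch of the same key-partitioning step on a toy state dict (keys and shapes are invented for illustration):

import torch

state_dict = {
    "pretrained_model.model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
    "v_head.summary.weight": torch.zeros(1, 4),
    "v_head.summary.bias": torch.zeros(1),
}

decoder_state_dict, v_head_state_dict = {}, {}
for name, param in state_dict.items():
    if name.startswith("v_head."):  # value-head weights are saved separately
        v_head_state_dict[name] = param
    else:  # strip the deepspeed zero3 prefix if present
        decoder_state_dict[name.replace("pretrained_model.", "")] = param

print(sorted(decoder_state_dict))  # ['model.layers.0.self_attn.q_proj.weight']
print(sorted(v_head_state_dict))   # ['v_head.summary.bias', 'v_head.summary.weight']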

View File

@@ -2,59 +2,52 @@ import importlib.metadata
import importlib.util


def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


def _get_package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except Exception:
        return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_requests_available = is_package_available("requests")
_rouge_available = is_package_available("rouge_chinese")
_starlette_available = is_package_available("sse_starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
    return _is_package_available("fastapi")


def is_flash_attn2_available():
    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


def is_jieba_available():
    return _is_package_available("jieba")


def is_matplotlib_available():
    return _is_package_available("matplotlib")


def is_nltk_available():
    return _is_package_available("nltk")


def is_requests_available():
    return _is_package_available("requests")


def is_rouge_available():
    return _is_package_available("rouge_chinese")


def is_starlette_available():
    return _is_package_available("sse_starlette")


def is_unsloth_available():
    return _is_package_available("unsloth")


def is_uvicorn_available():
    return _is_package_available("uvicorn")
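A brief, hedged sketch of how these guards are meant to protect optional features; the wrapper function below is hypothetical and only assumes that llmtuner.extras.packages is importable:

from llmtuner.extras.packages import is_matplotlib_available


def maybe_plot(values):
    # skip plotting gracefully when the optional dependency is missing
    if not is_matplotlib_available():
        print("matplotlib is not installed, skipping the plot")
        return

    import matplotlib.pyplot as plt

    plt.plot(values)
    plt.savefig("loss_curve.png")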

View File

@@ -1,224 +1,197 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    Cache,
    LlamaAttention,
    LlamaFlashAttention2,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_torch_attn_forward(
    self: "LlamaAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional["Cache"] = None,
    output_attentions: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
    attn_output = attn_output.transpose(1, 2).contiguous()

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# cast to half precision
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(self.config.torch_dtype)
key_states = key_states.to(self.config.torch_dtype)
value_states = value_states.to(self.config.torch_dtype)
if getattr(self, "num_key_value_groups", None):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
if attention_mask is not None:
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
# -q_len: assumes left padding when q_len != kv_len
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
attn_output_unpad = flash_attn_varlen_func(
unpadded_q,
unpadded_k,
unpadded_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True,
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
    self: "LlamaFlashAttention2",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # LlamaFlashAttention2 attention does not support output_attentions
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    query_states = query_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
    key_states = key_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
    value_states = value_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)

    dropout_rate = self.attention_dropout if self.training else 0.0

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once("The input hidden states seems to be silently casted in float32.")
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

    attn_output: torch.Tensor = self._flash_attention_forward(
        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
    )

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def apply_llama_patch() -> None:
    LlamaAttention.forward = llama_torch_attn_forward
    LlamaFlashAttention2.forward = llama_flash_attn_forward
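To make the grouping arithmetic above concrete, here is a small, self-contained sketch (all sizes are invented for illustration) of the shift-and-regroup step that S^2-Attn applies before computing attention per group: half of the heads are rolled by half a group along the sequence axis so information can flow across neighboring groups.

import torch

bsz, seq_len, n_heads, head_dim = 1, 8, 4, 2
group_size_ratio = 0.25
groupsz = int(seq_len * group_size_ratio)   # 2 tokens per group
num_groups = seq_len // groupsz             # 4 groups

# dummy states with shape (bsz, seq_len, n_heads, head_dim)
state = torch.arange(bsz * seq_len * n_heads * head_dim, dtype=torch.float32).reshape(
    bsz, seq_len, n_heads, head_dim
)

shifted = torch.cat(
    (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
    dim=2,
)  # roll the second half of the heads by half a group
grouped = shifted.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
print(grouped.shape)  # torch.Size([4, 2, 4, 2]): attention is then computed per group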

View File

@@ -0,0 +1,38 @@
import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock


def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    current_hidden_states = self.w2(current_hidden_states)
    return current_hidden_states


# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    topk_weight = topk_weight.to(hidden_states.dtype)

    hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
    y = torch.empty_like(hidden_states)
    flat_topk_idx = topk_idx.view(-1)
    for i in range(self.num_experts):
        expert = self.experts[i]
        y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])

    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits


def patch_mixtral_replace_moe_impl() -> None:
    MixtralBLockSparseTop2MLP.forward = mlp_forward
    MixtralSparseMoeBlock.forward = moe_forward
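A tiny, self-contained sketch of the top-k routing math used in moe_forward above (expert count, token count, and logits are arbitrary here):

import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 5, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

# softmax over experts, then keep the top-k experts per token
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)

# renormalize so the selected experts' weights sum to 1 for every token
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

print(topk_idx.shape)            # torch.Size([5, 2])
print(topk_weight.sum(dim=-1))   # tensor([1., 1., 1., 1., 1.])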

View File

@@ -1,11 +1,13 @@
import json
import math
import os
from typing import List, Optional

from transformers.trainer import TRAINER_STATE_NAME

from .logging import get_logger
from .packages import is_matplotlib_available


if is_matplotlib_available():
    import matplotlib.pyplot as plt
@@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
""" """
last = scalars[0] last = scalars[0]
smoothed = list() smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars: for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val) smoothed.append(smoothed_val)
@@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)
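For intuition, a runnable sketch of the smoothing used above: an exponential moving average whose weight grows with the series length through a sigmoid (the sample data is made up, and the standalone function is only an approximation of the project's smooth helper):

import math


def smooth(scalars):
    # EMA whose smoothing weight saturates as the series gets longer
    last = scalars[0]
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)
    smoothed = []
    for next_val in scalars:
        last = last * weight + (1 - weight) * next_val
        smoothed.append(last)
    return smoothed


print(smooth([2.0, 1.5, 1.8, 1.2, 1.0, 0.9]))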

View File

@@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args


__all__ = [
    "DataArguments",
    "EvaluationArguments",
    "FinetuningArguments",
    "GeneratingArguments",
    "ModelArguments",
    "get_eval_args",
    "get_infer_args",
    "get_train_args",
]

View File

@@ -1,33 +1,5 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import dataclass, field
from typing import Literal, Optional

DATA_CONFIG = "dataset_info.json"


@dataclass
class DatasetAttr:
    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    system_prompt: Optional[str] = None
    subset: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    messages: Optional[str] = "conversations"
    role: Optional[str] = "from"
    content: Optional[str] = "value"

    def __repr__(self) -> str:
        return self.dataset_name


@dataclass
@@ -35,85 +7,84 @@ class DataArguments:
r""" r"""
Arguments pertaining to what data we are going to input our model for training and evaluation. Arguments pertaining to what data we are going to input our model for training and evaluation.
""" """
template: Optional[str] = field( template: Optional[str] = field(
default=None, default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."} metadata={"help": "Which template to use for constructing prompts in training and inference."},
) )
dataset: Optional[str] = field( dataset: Optional[str] = field(
default=None, default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
) )
dataset_dir: Optional[str] = field( dataset_dir: Optional[str] = field(
default="data", default="data",
metadata={"help": "Path to the folder containing the datasets."} metadata={"help": "Path to the folder containing the datasets."},
) )
split: Optional[str] = field( split: Optional[str] = field(
default="train", default="train",
metadata={"help": "Which dataset split to use for training and evaluation."} metadata={"help": "Which dataset split to use for training and evaluation."},
) )
cutoff_len: Optional[int] = field( cutoff_len: Optional[int] = field(
default=1024, default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."} metadata={"help": "The cutoff length of the model inputs after tokenization."},
) )
reserved_label_len: Optional[int] = field( reserved_label_len: Optional[int] = field(
default=1, default=1,
metadata={"help": "The maximum length reserved for label after tokenization."} metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
) )
train_on_prompt: Optional[bool] = field( train_on_prompt: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."} metadata={"help": "Whether to disable the mask on the prompt or not."},
) )
streaming: Optional[bool] = field( streaming: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Enable dataset streaming."} metadata={"help": "Enable dataset streaming."},
) )
buffer_size: Optional[int] = field( buffer_size: Optional[int] = field(
default=16384, default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."} metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
) )
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat", default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."} metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
) )
interleave_probs: Optional[str] = field( interleave_probs: Optional[str] = field(
default=None, default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."} metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
) )
overwrite_cache: Optional[bool] = field( overwrite_cache: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."} metadata={"help": "Overwrite the cached training and evaluation sets."},
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of processes to use for the preprocessing."} metadata={"help": "The number of processes to use for the preprocessing."},
) )
max_samples: Optional[int] = field( max_samples: Optional[int] = field(
default=None, default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
) )
eval_num_beams: Optional[int] = field( eval_num_beams: Optional[int] = field(
default=None, default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
) )
ignore_pad_token_for_loss: Optional[bool] = field( ignore_pad_token_for_loss: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} metadata={
) "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
system_prompt: Optional[str] = field( },
default=None,
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
) )
val_size: Optional[float] = field( val_size: Optional[float] = field(
default=0, default=0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
) )
sft_packing: Optional[bool] = field( sft_packing: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
) )
cache_path: Optional[str] = field( cache_path: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to save or load the preprocessed datasets."} metadata={"help": "Path to save or load the preprocessed datasets."},
) )
def __post_init__(self): def __post_init__(self):
@@ -125,55 +96,3 @@ class DataArguments:
        if self.streaming and self.max_samples is not None:
            raise ValueError("`max_samples` is incompatible with `streaming`.")

        if self.streaming and self.cache_path:
            raise ValueError("`cache_path` is incompatible with `streaming`.")
    def init_for_training(self, seed: int):  # support mixing multiple datasets
        self.seed = seed
        dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
        try:
            with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
                dataset_info = json.load(f)
        except Exception as err:
            if self.dataset is not None:
                raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
            dataset_info = None

        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."

        if self.interleave_probs is not None:
            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]

        self.dataset_list: List[DatasetAttr] = []
        for i, name in enumerate(dataset_names):
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

            if "hf_hub_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
            elif "script_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    dataset_name=dataset_info[name]["file_name"],
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                )

            if "columns" in dataset_info[name]:
                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history = dataset_info[name]["columns"].get("history", None)
                dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
                dataset_attr.role = dataset_info[name]["columns"].get("role", None)
                dataset_attr.content = dataset_info[name]["columns"].get("content", None)

            dataset_attr.subset = dataset_info[name].get("subset", None)
            dataset_attr.ranking = dataset_info[name].get("ranking", False)
            dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
            dataset_attr.system_prompt = prompt_list[i]
            self.dataset_list.append(dataset_attr)
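To illustrate the comma-separated conventions used by the `dataset` and `interleave_probs` fields above, a small standalone sketch (the dataset names and probabilities are invented for the example):

dataset = "alpaca_en, alpaca_zh, oaast_sft"
interleave_probs = "0.5, 0.3, 0.2"

dataset_names = [ds.strip() for ds in dataset.split(",")]
probs = [float(p.strip()) for p in interleave_probs.split(",")]

assert len(dataset_names) == len(probs)
print(dataset_names, probs)
# ['alpaca_en', 'alpaca_zh', 'oaast_sft'] [0.5, 0.3, 0.2]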

View File

@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass, field
from typing import Literal, Optional

from datasets import DownloadMode
@@ -10,46 +10,39 @@ class EvaluationArguments:
r""" r"""
Arguments pertaining to specify the evaluation parameters. Arguments pertaining to specify the evaluation parameters.
""" """
task: str = field( task: str = field(
metadata={"help": "Name of the evaluation task."} metadata={"help": "Name of the evaluation task."},
) )
task_dir: Optional[str] = field( task_dir: Optional[str] = field(
default="evaluation", default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."} metadata={"help": "Path to the folder containing the evaluation datasets."},
) )
batch_size: Optional[int] = field( batch_size: Optional[int] = field(
default=4, default=4,
metadata={"help": "The batch size per GPU for evaluation."} metadata={"help": "The batch size per GPU for evaluation."},
) )
seed: Optional[int] = field( seed: Optional[int] = field(
default=42, default=42,
metadata={"help": "Random seed to be used with data loaders."} metadata={"help": "Random seed to be used with data loaders."},
) )
lang: Optional[Literal["en", "zh"]] = field( lang: Optional[Literal["en", "zh"]] = field(
default="en", default="en",
metadata={"help": "Language used at evaluation."} metadata={"help": "Language used at evaluation."},
) )
n_shot: Optional[int] = field( n_shot: Optional[int] = field(
default=5, default=5,
metadata={"help": "Number of examplars for few-shot learning."} metadata={"help": "Number of examplars for few-shot learning."},
) )
save_dir: Optional[str] = field( save_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to save the evaluation results."} metadata={"help": "Path to save the evaluation results."},
) )
download_mode: Optional[DownloadMode] = field( download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS, default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."} metadata={"help": "Download mode used for the evaluation datasets."},
) )
def __post_init__(self): def __post_init__(self):
        task_available = []
        for folder in os.listdir(self.task_dir):
            if os.path.isdir(os.path.join(self.task_dir, folder)):
                task_available.append(folder)

        if self.task not in task_available:
            raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
        if self.save_dir is not None and os.path.exists(self.save_dir):
            raise ValueError("`save_dir` already exists, use another one.")

View File

@@ -1,6 +1,6 @@
import json
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional


@dataclass
@@ -8,19 +8,23 @@ class FreezeArguments:
r""" r"""
Arguments pertaining to the freeze (partial-parameter) training. Arguments pertaining to the freeze (partial-parameter) training.
""" """
name_module_trainable: Optional[str] = field( name_module_trainable: Optional[str] = field(
default="mlp", default=None,
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ metadata={
Use commas to separate multiple modules. \ "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"mlp\", \"self_attn\"], \ Use commas to separate multiple modules. \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \ Use "all" to specify all the available modules. \
Qwen choices: [\"mlp\", \"attn\"], \ LLaMA choices: ["mlp", "self_attn"], \
Phi-1.5 choices: [\"mlp\", \"mixer\"], \ BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
Others choices: the same as LLaMA."} Qwen choices: ["mlp", "attn"], \
InternLM2 choices: ["feed_forward", "attention"], \
Others choices: the same as LLaMA."""
},
) )
num_layer_trainable: Optional[int] = field( num_layer_trainable: Optional[int] = field(
default=3, default=3,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."} metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
) )
@@ -29,35 +33,53 @@ class LoraArguments:
r""" r"""
Arguments pertaining to the LoRA training. Arguments pertaining to the LoRA training.
""" """
additional_target: Optional[str] = field( additional_target: Optional[str] = field(
default=None, default=None,
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} metadata={
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
},
) )
lora_alpha: Optional[float] = field( lora_alpha: Optional[int] = field(
default=None, default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."} metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
) )
lora_dropout: Optional[float] = field( lora_dropout: Optional[float] = field(
default=0.1, default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."} metadata={"help": "Dropout rate for the LoRA fine-tuning."},
) )
lora_rank: Optional[int] = field( lora_rank: Optional[int] = field(
default=8, default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
) )
lora_target: Optional[str] = field( lora_target: Optional[str] = field(
default=None, default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ metadata={
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ "help": """Name(s) of target modules to apply LoRA. \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \ Use commas to separate multiple modules. \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ Use "all" to specify all the available modules. \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
Others choices: the same as LLaMA."} Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
Others choices: the same as LLaMA."""
},
) )
resume_lora_training: Optional[bool] = field( lora_bf16_mode: Optional[bool] = field(
default=True, default=False,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
)
use_rslora: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
)
use_dora: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
)
create_new_adapter: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
) )
@@ -66,61 +88,70 @@ class RLHFArguments:
r""" r"""
Arguments pertaining to the PPO and DPO training. Arguments pertaining to the PPO and DPO training.
""" """
dpo_beta: Optional[float] = field( dpo_beta: Optional[float] = field(
default=0.1, default=0.1,
metadata={"help": "The beta parameter for the DPO loss."} metadata={"help": "The beta parameter for the DPO loss."},
)
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_ftx: Optional[float] = field(
default=0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
) )
ppo_buffer_size: Optional[int] = field( ppo_buffer_size: Optional[int] = field(
default=1, default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."} metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
) )
ppo_epochs: Optional[int] = field( ppo_epochs: Optional[int] = field(
default=4, default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."} metadata={"help": "The number of epochs to perform in a PPO optimization step."},
) )
ppo_logger: Optional[str] = field( ppo_logger: Optional[str] = field(
default=None, default=None,
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."} metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
) )
ppo_score_norm: Optional[bool] = field( ppo_score_norm: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Use score normalization in PPO training."} metadata={"help": "Use score normalization in PPO training."},
) )
ppo_target: Optional[float] = field( ppo_target: Optional[float] = field(
default=6.0, default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."} metadata={"help": "Target KL value for adaptive KL control in PPO training."},
) )
ppo_whiten_rewards: Optional[bool] = field( ppo_whiten_rewards: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."} metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
) )
ref_model: Optional[str] = field( ref_model: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."} metadata={"help": "Path to the reference model used for the PPO or DPO training."},
) )
ref_model_checkpoint: Optional[str] = field( ref_model_adapters: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."} metadata={"help": "Path to the adapters of the reference model."},
) )
ref_model_quantization_bit: Optional[int] = field( ref_model_quantization_bit: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the reference model."} metadata={"help": "The number of bits to quantize the reference model."},
) )
reward_model: Optional[str] = field( reward_model: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."} metadata={"help": "Path to the reward model used for the PPO training."},
) )
reward_model_checkpoint: Optional[str] = field( reward_model_adapters: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."} metadata={"help": "Path to the adapters of the reward model."},
) )
reward_model_quantization_bit: Optional[int] = field( reward_model_quantization_bit: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the reward model."} metadata={"help": "The number of bits to quantize the reward model."},
) )
reward_model_type: Optional[Literal["lora", "full", "api"]] = field( reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
default="lora", default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."} metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
) )
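Since several fields above parameterize the DPO objective, here is a hedged, self-contained sketch of the standard sigmoid DPO loss that `dpo_beta` scales; the log-probabilities below are dummy values, not produced by the trainer:

import torch
import torch.nn.functional as F

beta = 0.1  # corresponds to dpo_beta

# per-example log-probs of chosen/rejected responses under the policy and the reference model
policy_chosen_logps = torch.tensor([-12.3, -9.8])
policy_rejected_logps = torch.tensor([-14.1, -11.0])
ref_chosen_logps = torch.tensor([-12.0, -10.2])
ref_rejected_logps = torch.tensor([-13.0, -10.8])

pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = ref_chosen_logps - ref_rejected_logps
losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))  # the "sigmoid" variant of dpo_loss
print(losses.mean())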
@@ -129,33 +160,26 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r""" r"""
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft", default="sft",
metadata={"help": "Which stage will be performed in training."} metadata={"help": "Which stage will be performed in training."},
) )
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora", default="lora",
metadata={"help": "Which fine-tuning method to use."} metadata={"help": "Which fine-tuning method to use."},
) )
upcast_layernorm: Optional[bool] = field( use_llama_pro: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."} metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
) )
neft_alpha: Optional[float] = field( disable_version_checking: Optional[bool] = field(
default=0, default=False,
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."} metadata={"help": "Whether or not to disable version checking."},
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."}
) )
plot_loss: Optional[bool] = field( plot_loss: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."} metadata={"help": "Whether or not to save the training loss curves."},
) )
    def __post_init__(self):
@@ -165,21 +189,22 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
            return arg

        self.name_module_trainable = split_arg(self.name_module_trainable)
        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
        self.lora_target = split_arg(self.lora_target)
        self.additional_target = split_arg(self.additional_target)

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("`reward_model` is necessary for PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

        if self.use_llama_pro and self.finetuning_type == "full":
            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
    def save_to_json(self, json_path: str):
        r"""Saves the content of this instance in JSON format inside `json_path`."""

View File

@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@dataclass
@@ -7,41 +7,44 @@ class GeneratingArguments:
r""" r"""
Arguments pertaining to specify the decoding parameters. Arguments pertaining to specify the decoding parameters.
""" """
do_sample: Optional[bool] = field( do_sample: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
) )
temperature: Optional[float] = field( temperature: Optional[float] = field(
default=0.95, default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."} metadata={"help": "The value used to modulate the next token probabilities."},
) )
top_p: Optional[float] = field( top_p: Optional[float] = field(
default=0.7, default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
},
) )
top_k: Optional[int] = field( top_k: Optional[int] = field(
default=50, default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
) )
num_beams: Optional[int] = field( num_beams: Optional[int] = field(
default=1, default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."} metadata={"help": "Number of beams for beam search. 1 means no beam search."},
) )
max_length: Optional[int] = field( max_length: Optional[int] = field(
default=512, default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
) )
max_new_tokens: Optional[int] = field( max_new_tokens: Optional[int] = field(
default=512, default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
) )
repetition_penalty: Optional[float] = field( repetition_penalty: Optional[float] = field(
default=1.0, default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
) )
length_penalty: Optional[float] = field( length_penalty: Optional[float] = field(
default=1.0, default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
) )
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
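For context, a hedged sketch of how such a decoding-argument dataclass is typically flattened into keyword arguments; the dataclass below is a stand-in for illustration, not the project's class:

from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class DecodingArgs:
    do_sample: bool = True
    temperature: float = 0.95
    top_p: float = 0.7
    max_new_tokens: int = 512

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


args = DecodingArgs(temperature=0.6)
generate_kwargs = args.to_dict()  # e.g. later passed as model.generate(**generate_kwargs, ...)
print(generate_kwargs)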

View File

@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional


@dataclass
@@ -7,57 +7,119 @@ class ModelArguments:
r""" r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune. Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
""" """
model_name_or_path: str = field( model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from \ metadata={
huggingface.co/models or modelscope.cn/models."} "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
) )
cache_dir: Optional[str] = field( cache_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
) )
use_fast_tokenizer: Optional[bool] = field( use_fast_tokenizer: Optional[bool] = field(
default=True, default=False,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
) )
split_special_tokens: Optional[bool] = field( split_special_tokens: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
) )
model_revision: Optional[str] = field( model_revision: Optional[str] = field(
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
) )
quantization_bit: Optional[int] = field( quantization_bit: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the model."} metadata={"help": "The number of bits to quantize the model."},
) )
quantization_type: Optional[Literal["fp4", "nf4"]] = field( quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4", default="nf4",
metadata={"help": "Quantization data type to use in int4 training."} metadata={"help": "Quantization data type to use in int4 training."},
) )
double_quantization: Optional[bool] = field( double_quantization: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."} metadata={"help": "Whether or not to use double quantization in int4 training."},
) )
rope_scaling: Optional[Literal["linear", "dynamic"]] = field( rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None, default=None,
metadata={"help": "Adopt scaled rotary positional embeddings."} metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
) )
flash_attn: Optional[bool] = field( flash_attn: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."} metadata={"help": "Enable FlashAttention-2 for faster training."},
) )
shift_attn: Optional[bool] = field( shift_attn: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
use_unsloth: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
disable_gradient_checkpointing: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
) )
hf_hub_token: Optional[str] = field( hf_hub_token: Optional[str] = field(
default=None, default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."} metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: Optional[int] = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
print_param_status: Optional[bool] = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
) )
def __post_init__(self): def __post_init__(self):
@@ -67,10 +129,14 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer: if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.checkpoint_dir is not None: # support merging multiple lora weights if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return asdict(self) return asdict(self)
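As a hedged illustration of the new __post_init__ behaviour, the standalone sketch below reproduces only the comma-splitting of adapter_name_or_path on a stand-in dataclass (DemoModelArguments, a name invented for this example); it is not the project's ModelArguments.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoModelArguments:
    # Stand-in reproducing only the post-init behaviour from the diff above:
    # a comma-separated string of adapters becomes a list of stripped paths.
    adapter_name_or_path: Optional[str] = None
    quantization_bit: Optional[int] = None

    def __post_init__(self):
        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."


args = DemoModelArguments(adapter_name_or_path="saves/lora-step-1000, saves/lora-step-2000")
print(args.adapter_name_or_path)  # ['saves/lora-step-1000', 'saves/lora-step-2000']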

View File

@@ -0,0 +1,275 @@
import logging
import os
import sys
from typing import Any, Dict, Optional, Tuple
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
logger = get_logger(__name__)
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _check_dependencies(disabled: bool) -> None:
if disabled:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if unknown_args:
print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
return (*parsed_args,)
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if (
training_args.do_train
and finetuning_args.finetuning_type == "freeze"
and finetuning_args.name_module_trainable is None
):
raise ValueError("Please specify `name_module_trainable` in Freeze training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if finetuning_args.use_dora:
if model_args.quantization_bit is not None:
raise ValueError("DoRA does not support quantization.")
if model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info(
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
training_args.resume_from_checkpoint
)
)
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
)
# Post-process model arguments
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info(
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
str(model_args.compute_dtype),
)
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args
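For readers who want to see the argument-parsing dispatch in isolation, here is a hedged, self-contained sketch built on a toy dataclass (DemoArguments, invented for this example) that follows the same dict → .yaml → .json → CLI order as _parse_args above; it is not the project's parser.

import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    model_name_or_path: str = field(default="gpt2")
    learning_rate: float = field(default=5e-5)


def parse_demo_args(args: Optional[Dict[str, Any]] = None) -> Tuple[DemoArguments]:
    # Same dispatch order as _parse_args above: an explicit dict wins, then a
    # single .yaml or .json path on the command line, then ordinary CLI flags.
    parser = HfArgumentParser(DemoArguments)
    if args is not None:
        return parser.parse_dict(args)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    return parser.parse_args_into_dataclasses()


(demo_args,) = parse_demo_args({"model_name_or_path": "gpt2", "learning_rate": 1e-4})
print(demo_args)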

View File

@@ -1,5 +1,5 @@
-# Level: loader > adapter > parser, utils
-from llmtuner.model.loader import load_model_and_tokenizer
-from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
+from .loader import load_model_and_tokenizer
+from .utils import dispatch_model, load_valuehead_params
+
+__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]

View File

@@ -1,23 +1,24 @@
-import torch
 from typing import TYPE_CHECKING
-from peft import PeftModel, TaskType, LoraConfig, get_peft_model

-from llmtuner.extras.logging import get_logger
-from llmtuner.model.utils import find_all_linear_modules
+import torch
+from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ..extras.logging import get_logger
+from .utils import find_all_linear_modules, find_expanded_modules

 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import ModelArguments, FinetuningArguments
+
+    from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)


 def init_adapter(
-    model: "PreTrainedModel",
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
-    is_trainable: bool
+    model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
 ) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
@@ -27,8 +28,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.info("Checkpoint is not found at evaluation, load the original model.")
+    if (not is_trainable) and model_args.adapter_name_or_path is None:
+        logger.info("Adapter is not found at evaluation, load the base model.")
         return model

     if finetuning_args.finetuning_type == "full" and is_trainable:
@@ -44,65 +45,115 @@ def init_adapter(
         )
         if not num_layers:
             raise ValueError("Current model does not support freeze tuning.")

-        if finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
-            trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
-        else:  # fine-tuning the first n layers if num_layer_trainable < 0
-            trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
+        if finetuning_args.use_llama_pro:
+            if num_layers % finetuning_args.num_layer_trainable != 0:
+                raise ValueError(
+                    "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
+                        num_layers, finetuning_args.num_layer_trainable
+                    )
+                )
+
+            stride = num_layers // finetuning_args.num_layer_trainable
+            trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+        elif finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
+        else:  # fine-tuning the first n layers if num_layer_trainable < 0
+            trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
+
+        freeze_modules = {"all"}
+        for name, _ in model.named_modules():
+            if ".0." in name:
+                freeze_modules.add(name.split(".0.")[-1].split(".")[0])

         trainable_layers = []
         for module_name in finetuning_args.name_module_trainable:
+            if module_name not in freeze_modules:
+                raise ValueError(
+                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
+                )
+
             for idx in trainable_layer_ids:
-                trainable_layers.append("{:d}.{}".format(idx, module_name))
+                trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))

         for name, param in model.named_parameters():
-            if not any(trainable_layer in name for trainable_layer in trainable_layers):
-                param.requires_grad_(False)
-            else:
+            if any(trainable_layer in name for trainable_layer in trainable_layers):
                 param.data = param.data.to(torch.float32)
+            else:
+                param.requires_grad_(False)
+
+        logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))

     if finetuning_args.finetuning_type == "lora":
-        logger.info("Fine-tuning method: LoRA")
-        checkpoint_to_resume = None
+        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        adapter_to_resume = None

-        if model_args.checkpoint_dir is not None:
+        if model_args.adapter_name_or_path is not None:
             is_mergeable = True
-            if getattr(model, "quantization_method", None) == "gptq":
-                assert len(model_args.checkpoint_dir) == 1, "GPTQ quantized model only accepts a single checkpoint."
+            if getattr(model, "quantization_method", None):  # merge lora in quantized model is unstable
+                assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
                 is_mergeable = False

-            if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable):
-                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
-            else:
-                checkpoints_to_merge = model_args.checkpoint_dir
+            if is_deepspeed_zero3_enabled():
+                assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
+                is_mergeable = False

-            for checkpoint in checkpoints_to_merge:
-                model = PeftModel.from_pretrained(model, checkpoint)
+            if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
+                adapter_to_merge = model_args.adapter_name_or_path[:-1]
+                adapter_to_resume = model_args.adapter_name_or_path[-1]
+            else:
+                adapter_to_merge = model_args.adapter_name_or_path
+
+            for adapter in adapter_to_merge:
+                model: "LoraModel" = PeftModel.from_pretrained(model, adapter)
                 model = model.merge_and_unload()

-            if len(checkpoints_to_merge) > 0:
-                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
+            if len(adapter_to_merge) > 0:
+                logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))

-            if checkpoint_to_resume is not None:  # resume lora training
-                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
+            if adapter_to_resume is not None:  # resume lora training
+                model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)

-        if is_trainable and checkpoint_to_resume is None:  # create new lora weights while training
+        if is_trainable and adapter_to_resume is None:  # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model)
             else:
                 target_modules = finetuning_args.lora_target

-            lora_config = LoraConfig(
-                task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                r=finetuning_args.lora_rank,
-                lora_alpha=finetuning_args.lora_alpha,
-                lora_dropout=finetuning_args.lora_dropout,
-                target_modules=target_modules,
-                modules_to_save=finetuning_args.additional_target
-            )
-            model = get_peft_model(model, lora_config)
+            if finetuning_args.use_llama_pro:
+                target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
+
+            if finetuning_args.use_dora:
+                if getattr(model, "quantization_method", None):
+                    raise ValueError("DoRA is currently not compatible with quantized models.")
+
+            peft_kwargs = {
+                "r": finetuning_args.lora_rank,
+                "target_modules": target_modules,
+                "lora_alpha": finetuning_args.lora_alpha,
+                "lora_dropout": finetuning_args.lora_dropout,
+                "use_rslora": finetuning_args.use_rslora,
+            }
+
+            if model_args.use_unsloth:
+                from unsloth import FastLanguageModel  # type: ignore
+
+                unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
+                model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+            else:
+                lora_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    modules_to_save=finetuning_args.additional_target,
+                    use_dora=finetuning_args.use_dora,
+                    **peft_kwargs,
+                )
+                model = get_peft_model(model, lora_config)
+
+        for param in filter(lambda p: p.requires_grad, model.parameters()):
+            param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)

-        if model_args.checkpoint_dir is not None:
-            logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+        if model_args.adapter_name_or_path is not None:
+            logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     return model
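The `lora_target: all` branch above relies on a helper that scans the model for linear layers. As a hedged, simplified analogue (not the project's find_all_linear_modules), the sketch below collects nn.Linear leaf names from a toy torch module.

from typing import List

import torch


def find_linear_module_names(model: torch.nn.Module, output_layer: str = "lm_head") -> List[str]:
    # Collect the leaf names of every nn.Linear layer except the output head,
    # similar in spirit to how `lora_target: all` is resolved before LoraConfig
    # is built in the diff above.
    names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and output_layer not in name:
            names.add(name.split(".")[-1])
    return sorted(names)


toy = torch.nn.ModuleDict(
    {
        "q_proj": torch.nn.Linear(8, 8),
        "v_proj": torch.nn.Linear(8, 8),
        "lm_head": torch.nn.Linear(8, 32),
    }
)
print(find_linear_module_names(toy))  # ['q_proj', 'v_proj']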

View File

@@ -1,56 +1,31 @@
-import os
-import math
-import torch
-from types import MethodType
-from typing import TYPE_CHECKING, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    PretrainedConfig,
-    PreTrainedModel,
-    PreTrainedTokenizerBase
-)
-from transformers.models.llama import modeling_llama as LlamaModule
-from transformers.utils.versions import require_version
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers.integrations import is_deepspeed_zero3_enabled
 from trl import AutoModelForCausalLMWithValueHead

-try:
-    from transformers.integrations import is_deepspeed_zero3_enabled
-except ImportError:  # https://github.com/huggingface/transformers/releases/tag/v4.33.1
-    from transformers.deepspeed import is_deepspeed_zero3_enabled

-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
-from llmtuner.extras.packages import is_flash_attn2_available
-from llmtuner.extras.patches import llama_patch as LlamaPatches
-from llmtuner.hparams import FinetuningArguments
-from llmtuner.model.adapter import init_adapter
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
+from ..extras.logging import get_logger
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from .adapter import init_adapter
+from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
+from .utils import load_valuehead_params, register_autoclass

 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
-    from llmtuner.hparams import ModelArguments
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+    from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)

-require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
-require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
-require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")


 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False
-) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
+    add_valuehead: Optional[bool] = False,
+) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
@@ -63,176 +38,95 @@ def load_model_and_tokenizer(
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
-        "token": model_args.hf_hub_token
+        "token": model_args.hf_hub_token,
     }

     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         split_special_tokens=model_args.split_special_tokens,
-        padding_side="right",  # training with left-padded tensors in fp16 precision may cause overflow
-        **config_kwargs
+        padding_side="right",
+        **config_kwargs,
     )
+    patch_tokenizer(tokenizer)

-    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
-        logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
-        model_to_load = model_args.checkpoint_dir[0]
-    else:
-        model_to_load = model_args.model_name_or_path
-
-    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
-
-    # Fix tokenizer (for ChatGLM2 and ChatGLM3)
-    if getattr(config, "model_type", None) == "chatglm":
-        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
-
-    # Set model dtype
-    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
-        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-    setattr(config, "torch_dtype", model_args.compute_dtype)
-
-    # Fix config (for Qwen)
-    if getattr(config, "model_type", None) == "qwen":
-        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)

+    model = None
+    if is_trainable and model_args.use_unsloth:
+        from unsloth import FastLanguageModel  # type: ignore
+
+        unsloth_kwargs = {
+            "model_name": model_args.model_name_or_path,
+            "max_seq_length": model_args.model_max_length,
+            "dtype": model_args.compute_dtype,
+            "load_in_4bit": model_args.quantization_bit == 4,
+            "token": model_args.hf_hub_token,
+            "device_map": {"": get_current_device()},
+            "rope_scaling": getattr(config, "rope_scaling", None),
+        }
+        try:
+            model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
+        except NotImplementedError:
+            logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+            model_args.use_unsloth = False
+
+        if model_args.adapter_name_or_path:
+            model_args.adapter_name_or_path = None
+            logger.warning("Unsloth does not support loading adapters.")
+
+    if model is None:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            config=config,
+            torch_dtype=model_args.compute_dtype,
+            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
+            **config_kwargs,
+        )
+
+    patch_model(model, tokenizer, model_args, is_trainable)
+    register_autoclass(config, model, tokenizer)

-    # Set RoPE scaling
-    if model_args.rope_scaling is not None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
else:
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
model_args.rope_scaling, scaling_factor
))
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
# Set shift short attention (S^2-Attn)
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) == "llama":
setattr(config, "group_size_ratio", 0.25)
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using gptq or awq)
if getattr(config, "quantization_config", None):
if model_args.quantization_bit is not None: # remove bnb quantization
model_args.quantization_bit = None
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
# Disable custom generate method (for Qwen and Baichuan2)
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)

-    # Prepare model with valuehead for RLHF
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
-        setattr(model, "tie_weights", MethodType(lambda _: None, model))  # use empty method
-        vhead_path = (
-            model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
-        )
+        patch_valuehead_model(model)
+
+        if model_args.adapter_name_or_path is not None:
+            vhead_path = model_args.adapter_name_or_path[-1]
+        else:
+            vhead_path = model_args.model_name_or_path
+
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
             logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))

-    # Prepare model for inference
     if not is_trainable:
-        model.requires_grad_(False)  # fix all model params
-        model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
+        model.requires_grad_(False)
+        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
     else:
         model.train()

     trainable_params, all_param = count_parameters(model)
-    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-        trainable_params, all_param, 100 * trainable_params / all_param
-    ))
+    logger.info(
+        "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+            trainable_params, all_param, 100 * trainable_params / all_param
+        )
+    )

     if not is_trainable:
         logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

+    if model_args.print_param_status:
+        for name, param in model.named_parameters():
+            print(
+                "name: {}, dtype: {}, device: {}, trainable: {}".format(
+                    name, param.dtype, param.device, param.requires_grad
+                )
+            )
+
     return model, tokenizer
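A hedged sketch of the slimmed-down loading path, using plain transformers calls only: the checkpoint id is a placeholder, and the config patching, quantization and unsloth branches from the diff are deliberately omitted.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-0.5B"  # placeholder checkpoint id, any causal LM works

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, padding_side="right")
config = AutoConfig.from_pretrained(model_id)

# The refactored loader patches the config first (dtype, RoPE scaling,
# quantization) and then hands it to from_pretrained; the call below keeps
# only the plain-transformers part of that path.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
model.requires_grad_(False)
model.eval()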

View File

@@ -1,205 +0,0 @@
import os
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
logger = get_logger(__name__)
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments
data_args.init_for_training(training_args.seed)
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
_verify_model_args(model_args, finetuning_args)
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
training_args.resume_from_checkpoint
))
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
))
# postprocess model_args
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
_verify_model_args(model_args, finetuning_args)
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
_verify_model_args(model_args, finetuning_args)
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

View File

@@ -0,0 +1,334 @@
import math
import os
import random
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
if model_args.flash_attn:
if is_flash_attn2_available():
config_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention2 is not installed.")
config_kwargs["attn_implementation"] = None
else:
config_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_longlora(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quantization_config["use_exllama"] = False # disable exllama
if quantization_config.get("quant_method", None) == "aqlm":
quantization_config["bits"] = 2
logger.info(
"Loading {}-bit {}-quantized model.".format(
quantization_config.get("bits", "?"), quantization_config.get("quant_method", None)
)
)
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
config_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
logger.info("Upcasting layernorm weights in float32.")
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(fp32_forward_post_hook)
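Note: the forward hook registered above simply casts whatever the output layer returns to float32. A standalone sketch of the same pattern on a toy linear layer (the layer and dtypes here are illustrative only):

import torch

layer = torch.nn.Linear(8, 4).bfloat16()       # low-precision output layer stand-in

def cast_output_to_fp32(module, args, output):
    # post-hook: upcast logits so downstream loss is computed in fp32
    return output.to(torch.float32)

layer.register_forward_hook(cast_output_to_fp32)
print(layer(torch.randn(2, 8).bfloat16()).dtype)  # torch.float32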
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
_configure_attn_implementation(model_args, config_kwargs)
if model_args.rope_scaling is not None:
_configure_rope(config, model_args, is_trainable)
if is_trainable and model_args.shift_attn:
_configure_longlora(config)
_configure_quantization(config, tokenizer, model_args, config_kwargs)
def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
) -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer)
if is_trainable:
_prepare_model_for_training(model, model_args)
if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if is_trainable:
patch_mixtral_replace_moe_impl()
try:
model.add_model_tags(["llama-factory"])
except Exception:
logger.warning("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
self.pretrained_model.tie_weights()
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))

View File

@@ -1,17 +1,19 @@
-import torch
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List
+
+import torch
+from transformers import PreTrainedModel
 from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from llmtuner.extras.constants import LAYERNORM_NAMES
-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, FinetuningArguments
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import get_logger
+from ..extras.misc import get_current_device
 
 if TYPE_CHECKING:
-    from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import DataArguments
+    from transformers import PretrainedConfig, PreTrainedTokenizer
+
+    from ..hparams import ModelArguments
 
 logger = get_logger(__name__)
@@ -19,27 +21,32 @@ logger = get_logger(__name__)
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
     """
     if getattr(model, "quantization_method", None):  # already set on current device
         return model
 
-    if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
+    if (
+        torch.cuda.device_count() > 1
+        and isinstance(model, PreTrainedModel)
+        and model._no_split_modules is not None
+        and model.config.model_type != "chatglm"
+    ):
         from accelerate import dispatch_model
-        from accelerate.utils import infer_auto_device_map, get_balanced_memory
-
-        if model._no_split_modules is None:
-            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+        from accelerate.utils import get_balanced_memory, infer_auto_device_map
 
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
         max_memory = get_balanced_memory(model, **kwargs)
         # Make sure tied weights are tied before creating the device map.
         model.tie_weights()
         device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        return dispatch_model(model, device_map)
+        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
+        if "skip_keys" in inspect.signature(dispatch_model).parameters:
+            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+        return dispatch_model(model, **device_map_kwargs)
     else:
-        return model.cuda()
+        return model.to(device=get_current_device())
 
 
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
@@ -51,6 +58,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
         linear_cls = torch.nn.Linear
     elif quantization_method == "bitsandbytes":
         import bitsandbytes as bnb
+
         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
     else:
         raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
@@ -61,123 +69,72 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     module_names = set()
     for name, module in model.named_modules():
-        if (
-            isinstance(module, linear_cls)
-            and not any([output_layer in name for output_layer in output_layer_names])
-        ):
+        if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
             module_names.add(name.split(".")[-1])
 
     logger.info("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)
 
 
-def get_modelcard_args(
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
-    finetuning_args: "FinetuningArguments"
-) -> Dict[str, Any]:
-    return {
-        "tasks": "text-generation",
-        "license": "other",
-        "finetuned_from": model_args.model_name_or_path,
-        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
-        "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
-    }
+def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
+    r"""
+    Finds the modules in the expanded blocks to apply lora.
+    """
+    num_layers = getattr(model.config, "num_hidden_layers", None)
+    if not num_layers:
+        raise ValueError("Model was not supported.")
+
+    if num_layers % num_layer_trainable != 0:
+        raise ValueError(
+            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+        )
+
+    stride = num_layers // num_layer_trainable
+    trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    module_names = []
+    for name, _ in model.named_modules():
+        if any(target_module in name for target_module in target_modules) and any(
+            trainable_layer in name for trainable_layer in trainable_layers
+        ):
+            module_names.append(name)
+
+    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    return module_names
 
 
-def load_valuehead_params(
-    path_or_repo_id: str,
-    model_args: "ModelArguments"
-) -> Dict[str, torch.Tensor]:
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
 
     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
-    kwargs = {
-        "path_or_repo_id": path_or_repo_id,
-        "cache_dir": model_args.cache_dir
-    }
-
-    if "token" in inspect.signature(cached_file).parameters:
-        kwargs["token"] = model_args.hf_hub_token
-    elif "use_auth_token" in inspect.signature(cached_file).parameters:  # for transformers==4.31.0
-        kwargs["use_auth_token"] = model_args.hf_hub_token
-    else:
-        logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
-
-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
 
     try:
         from safetensors import safe_open
-        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        with safe_open(vhead_file, framework="pt", device="cpu") as f:
-            return {
-                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
-                "v_head.summary.bias": f.get_tensor("v_head.summary.bias")
-            }
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
-
-    logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+
+        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
+        with safe_open(vhead_file, framework="pt", device="cpu") as f:
+            return {key: f.get_tensor(key) for key in f.keys()}
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
+
+    try:
+        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
+
+    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
+    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
     return None
 
 
-def prepare_model_for_training(
-    model: "PreTrainedModel",
-    finetuning_args: "FinetuningArguments",
-    output_layer_name: Optional[str] = "lm_head",
-    use_gradient_checkpointing: Optional[bool] = True,
-    layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
-) -> "PreTrainedModel":
-    r"""
-    Includes:
-        (1) cast the layernorm in fp32
-        (2) make output embedding layer require grads
-        (3) upcast the lm_head to fp32
-    Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
-    """
-    if finetuning_args.upcast_layernorm:
-        for name, param in model.named_parameters():
-            if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
-                param.data = param.data.to(torch.float32)
-        logger.info("Upcasting weights in layernorm in float32.")
-
-    if finetuning_args.neft_alpha > 1e-6:
-        def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-            if module.training:
-                dims = torch.tensor(output.size(1) * output.size(2))
-                mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
-                output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
-            return output
-
-        model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
-        logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
-
-    if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        else:
-            def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-                output.requires_grad_(True)
-
-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-        model.gradient_checkpointing_enable()
-        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
-        logger.info("Gradient checkpointing enabled.")
-
-    if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
-        output_layer = getattr(model, output_layer_name)
-        if isinstance(output_layer, torch.nn.Linear):
-            def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
-                return args[0].to(output_layer.weight.dtype)
-
-            def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-                return output.to(torch.float32)
-
-            output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
-            output_layer.register_forward_hook(fp32_forward_post_hook)
-
-    return model
+def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
+    if "AutoConfig" in getattr(config, "auto_map", {}):
+        config.__class__.register_for_auto_class()
+    if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
+        model.__class__.register_for_auto_class()
+    if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
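Note: for the newly added find_expanded_modules, a worked example of the layer-selection arithmetic may help. With illustrative values num_layers = 32 and num_layer_trainable = 8, the stride is 4 and the selected layer ids are the last layer of each block of four:

num_layers, num_layer_trainable = 32, 8   # illustrative values, not taken from the diff
stride = num_layers // num_layer_trainable
trainable_layer_ids = list(range(stride - 1, num_layers + stride - 1, stride))
print(trainable_layer_ids)  # [3, 7, 11, 15, 19, 23, 27, 31]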

View File

@@ -1 +1,4 @@
-from llmtuner.train.tuner import export_model, run_exp
+from .tuner import export_model, run_exp
+
+
+__all__ = ["export_model", "run_exp"]

View File

@@ -1 +1,4 @@
-from llmtuner.train.dpo.workflow import run_dpo
+from .workflow import run_dpo
+
+
+__all__ = ["run_dpo"]

View File

@@ -1,6 +1,7 @@
-import torch
 from dataclasses import dataclass
 from typing import Any, Dict, List, Sequence, Tuple
+
+import torch
 from transformers import DataCollatorForSeq2Seq
@@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
             padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
             padded_tensor[start:end] = feature[start:end]
             padded_labels.append(padded_tensor)
-        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         r"""
@@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
         for key in ("chosen_ids", "rejected_ids"):
             for feature in features:
                 prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
-                concatenated_features.append({
-                    "input_ids": feature["prompt_ids"] + feature[key],
-                    "attention_mask": [1] * (prompt_len + answer_len)
-                })
+                concatenated_features.append(
+                    {
+                        "input_ids": feature["prompt_ids"] + feature[key],
+                        "attention_mask": [1] * (prompt_len + answer_len),
+                    }
+                )
                 label_positions.append((prompt_len, answer_len))
 
         batch = self.tokenizer.pad(

View File

@@ -1,40 +1,51 @@
-import torch
 from collections import defaultdict
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+
+import torch
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
-from llmtuner.extras.constants import IGNORE_INDEX
+
+from ...extras.constants import IGNORE_INDEX
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
 
 
 class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
+        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
-        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
-        **kwargs
+        **kwargs,
     ):
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
                 disable_dropout_in_model(ref_model)
 
-        self.is_encoder_decoder = model.config.is_encoder_decoder
-        self.ref_model = ref_model
+        self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
+        self.is_encoder_decoder = model.config.is_encoder_decoder
+        self.precompute_ref_log_probs = False
+        self._precomputed_train_ref_log_probs = False
+        self._precomputed_eval_ref_log_probs = False
+        self._peft_has_been_casted_to_bf16 = False
+
+        self.ref_model = ref_model
         self.beta = beta
+        self.label_smoothing = 0
         self.loss_type = loss_type
+        self.ftx_gamma = ftx_gamma
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
         Trainer.__init__(self, model=model, **kwargs)
@@ -44,32 +55,95 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
-                    getattr(ref_model, "is_loaded_in_8bit", False)
-                    or getattr(ref_model, "is_loaded_in_4bit", False)
+                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                 ):  # quantized models are already set on the correct device
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+    def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
+        r"""
+        Computes supervised cross-entropy loss of given labels under the given logits.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
+        """
+        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
+        return -all_logps
+
     def concatenated_forward(
-        self,
-        model: Optional[torch.nn.Module] = None,
-        batch: Optional[Dict[str, torch.Tensor]] = None
+        self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 
         all_logits = model(
-            input_ids=batch_copied["input_ids"],
-            attention_mask=batch_copied["attention_mask"],
-            return_dict=True
+            input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
         ).logits.to(torch.float32)
 
-        all_logps = self._get_batch_logps(
+        all_logps = self.get_batch_logps(
             all_logits,
             batch["labels"],
-            average_log_prob=False
+            average_log_prob=False,
+            label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+
+    def get_batch_loss_metrics(
+        self,
+        model: "PreTrainedModel",
+        batch: Dict[str, torch.Tensor],
+        train_eval: Optional[Literal["train", "eval"]] = "train",
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        r"""
+        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
+        """
+        metrics = {}
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+        ) = self.concatenated_forward(model, batch)
+
+        with torch.no_grad():
+            if self.ref_model is None:
+                ref_model = self.model
+                ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            else:
+                ref_model = self.ref_model
+                ref_context = nullcontext()
+
+            with ref_context:
+                (
+                    reference_chosen_logps,
+                    reference_rejected_logps,
+                    _,
+                    _,
+                ) = self.concatenated_forward(ref_model, batch)
+
+        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        )
+        if self.ftx_gamma > 1e-6:
+            batch_size = batch["input_ids"].size(0) // 2
+            chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
+            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+
+        return losses.mean(), metrics
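Note: the get_batch_loss_metrics added above combines the DPO objective with an optional SFT term weighted by ftx_gamma. A rough sketch of the sigmoid-variant arithmetic it relies on (this is the generic DPO formula written out for illustration, not code copied from trl):

import torch
import torch.nn.functional as F

def sigmoid_dpo_loss(pol_chosen, pol_rejected, ref_chosen, ref_rejected, beta=0.1):
    # implicit rewards are beta-scaled log-prob differences against the reference model
    chosen_rewards = beta * (pol_chosen - ref_chosen)
    rejected_rewards = beta * (pol_rejected - ref_rejected)
    losses = -F.logsigmoid(chosen_rewards - rejected_rewards)
    return losses, chosen_rewards, rejected_rewards

# toy log-probabilities for a batch of two preference pairs
losses, cr, rr = sigmoid_dpo_loss(
    torch.tensor([-10.0, -12.0]), torch.tensor([-11.0, -11.5]),
    torch.tensor([-10.5, -12.5]), torch.tensor([-10.8, -11.2]),
)
# with ftx_gamma > 0, the trainer above adds ftx_gamma * sft_loss(chosen) to `losses`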

View File

@@ -1,20 +1,23 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
 
-from typing import TYPE_CHECKING, Optional, List
+from typing import TYPE_CHECKING, List, Optional
+
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.train.dpo.trainer import CustomDPOTrainer
-from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model_and_tokenizer
+from ...train.dpo.collator import DPODataCollatorWithPadding
+from ...train.dpo.trainer import CustomDPOTrainer
+from ...train.utils import create_modelcard_and_push, create_ref_model
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import DataArguments, FinetuningArguments
+
+    from ...hparams import DataArguments, FinetuningArguments
 
 
 def run_dpo(
@@ -22,38 +25,39 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
     )
 
     # Create reference model
     if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
         ref_model = model
     else:
         ref_model = create_ref_model(model_args, finetuning_args)
 
     # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False))  # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
+        loss_type=finetuning_args.dpo_loss,
+        ftx_gamma=finetuning_args.dpo_ftx,
         model=model,
         ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args, training_args)
+        **split_dataset(dataset, data_args, training_args),
     )
 
     # Training
@@ -69,7 +73,7 @@ def run_dpo(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         if id(model) == id(ref_model):  # unable to compute rewards without a reference model
             remove_keys = [key for key in metrics.keys() if "rewards" in key]
             for key in remove_keys:
                 metrics.pop(key)

View File

@@ -1 +1,4 @@
-from llmtuner.train.ppo.workflow import run_ppo
+from .workflow import run_ppo
+
+
+__all__ = ["run_ppo"]

View File

@@ -1,27 +1,28 @@
+import math
 import os
 import sys
-import math
-import torch
-from tqdm import tqdm
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
-from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+import torch
+from tqdm import tqdm
+from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 
-from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
+from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
+from ...extras.logging import get_logger
+from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
+
+    from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
 
 logger = get_logger(__name__)
@@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
         reward_model: "AutoModelForCausalLMWithValueHead",
-        **kwargs
+        **kwargs,
     ):
         PPOTrainer.__init__(self, **kwargs)
@@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            **generating_args.to_dict()
+            **generating_args.to_dict(),
         )
 
         self.state = TrainerState()
@@ -61,7 +62,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.state, "deepspeed_plugin"
         )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
 
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             if not (
                 getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                 or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
             ):  # quantized models are already set on the correct device
                 self.reward_model = self._prepare_deepspeed(self.reward_model)
             else:
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
@@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info(" Num examples = {}".format(num_examples))
         logger.info(" Num Epochs = {}".format(num_train_epochs))
         logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
-        logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
-            total_train_batch_size
-        ))
+        logger.info(
+            " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+                total_train_batch_size
+            )
+        )
         logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
         logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
         logger.info(" Total training steps = {}".format(max_steps))
@@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()
 
             # Get inputs
             self.tokenizer.padding_side = "right"  # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+                mini_batch_queries, mini_batch_responses = self.get_inputs(
+                    batch[idx : idx + self.config.mini_batch_size]
+                )
                 mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
@@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             # Run PPO step
             stats = self.step(queries, responses, rewards)
             self.tokenizer.padding_side = "left"  # restore padding side
             loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
@@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                     batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                     self.log_stats(stats, batch, rewards)
-                except:
+                except Exception:
                     logger.warning("Failed to save stats due to unknown errors.")
 
             self.state.global_step += 1
             self.log_callback.on_step_end(self.args, self.state, self.control)
 
-            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
+            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                 logs = dict(
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
-                    epoch=round(step / steps_in_epoch, 2)
+                    epoch=round(step / steps_in_epoch, 2),
                 )
                 tqdm.write(str(logs))
                 logs["step"] = step
@@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 loss_meter.reset()
                 reward_meter.reset()
 
-            if (step+1) % self.args.save_steps == 0:  # save checkpoint
-                self.save_model(os.path.join(
-                    self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
-                ))
+            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
+                self.save_model(
+                    os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+                )
                 self.save_callback.on_save(
                     self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
                 )
@@ -204,33 +209,36 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
        Generates model's responses given queries.
         """
-        if self.finetuning_args.upcast_layernorm:
+        if self.model_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)
 
+        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
+            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            for k, v in batch.items():
+                batch[k] = v[:, start_index:]
+
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config,
-            logits_processor=get_logits_processor(),
-            **batch
+            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
        )
 
-        if self.finetuning_args.upcast_layernorm:
+        if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
 
         query = batch["input_ids"].detach().cpu()
-        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
+        response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
 
             if len(response_index) == 0:
                 response_length = 1  # allow empty response
             else:
                 response_length = response_index[-1].item() + 1
 
-            queries.append(query[i, query_length:])  # remove padding from left
+            queries.append(query[i, query_start_index:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
 
         return queries, responses
@@ -239,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self,
         queries: List[torch.Tensor],
         responses: List[torch.Tensor],
-        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+        unwrapped_model: "AutoModelForCausalLMWithValueHead",
     ) -> List[torch.Tensor]:
         r"""
         Computes scores using given reward model.
@@ -259,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
             values = torch.transpose(values, 0, 1)
 
         rewards = []
         for i in range(values.size(0)):
             end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
 
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -284,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         responses: torch.Tensor,
         model_inputs: dict,
         return_logits: Optional[bool] = False,
-        response_masks: Optional[torch.Tensor] = None
+        response_masks: Optional[torch.Tensor] = None,
     ):
         r"""
         Calculates model outputs in multiple batches.
@@ -307,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]
 
             with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
                 logits, _, values = model(**input_kwargs)
 
             unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -320,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             for j in range(len(query_batch)):
                 start = len(query_batch[j]) - 1
                 if attention_mask[j, 0] == 0:  # offset left padding
                     start += attention_mask[j, :].nonzero()[0].item()
                 end = start + len(response_batch[j])
 
                 if response_masks is not None:
-                    response_masks_batch = torch.cat(
-                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
-                    )[1:]
+                    response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
 
                 masks[j, :start] = 0
                 masks[j, end:] = 0
@@ -361,9 +367,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
             except ValueError:
                 logger.warning(
-                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                    " zero_to_fp32.py to recover weights"
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                    " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-            remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
-            self.model.save_checkpoint(output_dir)  # wrapped model
+            remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+            self.model.save_checkpoint(output_dir)
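Note: a small sketch of the reward-extraction step used in get_rewards above: the scalar reward for each sequence is the value-head output at the last non-padding position. The tensor contents below are invented for illustration.

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])  # right-padded batch
values = torch.randn(2, 5)                                    # value-head output per token

rewards = []
for i in range(values.size(0)):
    end_indexes = (input_ids[i] != pad_token_id).nonzero()
    end_index = end_indexes[-1].item() if len(end_indexes) else 0
    rewards.append(values[i, end_index].float())              # fp32 scalar reward per sequence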

View File

@@ -1,8 +1,12 @@
 import json
-import torch
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional
 
-from llmtuner.extras.packages import is_requests_available
+import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ...extras.packages import is_requests_available
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
@@ -21,16 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    if target == "reward":  # save default head temporarily
-        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
-        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
-        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
+    if is_deepspeed_zero3_enabled():
+        import deepspeed  # type: ignore
+
+        params = [model.v_head.summary.weight, model.v_head.summary.bias]
+        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+    else:
+        context_maybe_zero3 = nullcontext()
 
-    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
-    model.v_head.load_state_dict({
-        "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
-        "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
-    })
+    with context_maybe_zero3:
+        if target == "reward":  # save default head temporarily
+            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+
+        model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
+        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
 
 
 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:

View File

@@ -1,22 +1,26 @@
 # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
-from trl import PPOConfig
+from typing import TYPE_CHECKING, List, Optional
+
 from torch.optim import AdamW
-from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
+from trl import PPOConfig
 
-from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_ref_model, create_reward_model
-from llmtuner.train.ppo.trainer import CustomPPOTrainer
+from ...data import get_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.ppo.trainer import CustomPPOTrainer
+from ...train.utils import create_ref_model, create_reward_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
 def run_ppo(
@@ -25,13 +29,14 @@ def run_ppo(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
     generating_args: "GeneratingArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
+    model, tokenizer = load_model_and_tokenizer(
+        model_args, finetuning_args, training_args.do_train, add_valuehead=True
+    )
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
     # Create reference model and reward model
@@ -55,7 +60,8 @@ def run_ppo(
         use_score_scaling=finetuning_args.ppo_score_norm,
         use_score_norm=finetuning_args.ppo_score_norm,
         whiten_rewards=finetuning_args.ppo_whiten_rewards,
-        accelerator_kwargs={"step_scheduler_with_optimizer": False}
+        accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        project_kwargs={"logging_dir": training_args.logging_dir},
     )
 
     # Create optimizer and scheduler
@@ -70,7 +76,7 @@ def run_ppo(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
-        num_training_steps=num_training_steps
+        num_training_steps=num_training_steps,
     )
 
     # Initialize our Trainer
@@ -79,7 +85,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         reward_model=reward_model,
         config=ppo_config,
         model=model,
@@ -88,13 +94,15 @@ def run_ppo(
         dataset=dataset,
         data_collator=data_collator,
         optimizer=optimizer,
-        lr_scheduler=lr_scheduler
+        lr_scheduler=lr_scheduler,
     )
 
     # Training
     if training_args.do_train:
         ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
-        ppo_trainer.save_state()  # must be called after save_model to have a folder
+
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
+        ppo_trainer.save_state()  # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -1 +1,4 @@
-from llmtuner.train.pt.workflow import run_pt
+from .workflow import run_pt
+
+
+__all__ = ["run_pt"]

View File

@@ -1,17 +1,20 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py

 import math
-from typing import TYPE_CHECKING, Optional, List
+from typing import TYPE_CHECKING, List, Optional

 from transformers import DataCollatorForLanguageModeling, Trainer

-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.utils import create_modelcard_and_push

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
+    from ...hparams import DataArguments, FinetuningArguments, ModelArguments


 def run_pt(
@@ -19,11 +22,10 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

     # Initialize our Trainer
@@ -33,7 +35,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args, training_args)
+        **split_dataset(dataset, data_args, training_args),
     )

     # Training
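The same dataset-pipeline change recurs in the rm and sft workflows below: the former `get_dataset` plus `preprocess_dataset` pair is collapsed into a single `get_dataset` call that receives the tokenizer and the training stage. A before/after sketch using only the calls visible in these hunks:

    # before: load, then tokenize in a separate preprocessing pass
    dataset = get_dataset(model_args, data_args)
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")

    # after: a single call that loads and preprocesses for the given stage
    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")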

View File

@@ -1 +1,4 @@
-from llmtuner.train.rm.workflow import run_rm
+from .workflow import run_rm
+
+
+__all__ = ["run_rm"]

View File

@@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
@@ -20,8 +21,9 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
features = [ features = [
{ {
"input_ids": feature["prompt_ids"] + feature[key], "input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
} }
for key in ("chosen_ids", "rejected_ids") for feature in features for key in ("chosen_ids", "rejected_ids")
for feature in features
] ]
return super().__call__(features) return super().__call__(features)

View File

@@ -1,6 +1,7 @@
import numpy as np
from typing import Dict, Sequence, Tuple, Union from typing import Dict, Sequence, Tuple, Union
import numpy as np
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds preds, _ = eval_preds

View File

@@ -1,14 +1,16 @@
import os
import json import json
import torch import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer from transformers import Trainer
from llmtuner.extras.logging import get_logger from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.trainer import PredictionOutput
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -21,13 +23,10 @@ class PairwiseTrainer(Trainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss self.can_return_loss = True # override property to return eval_loss
def compute_loss( def compute_loss(
self, self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r""" r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@@ -68,9 +67,9 @@ class PairwiseTrainer(Trainer):
assert div_index > 0 assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index] chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index] rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the last token except pad token for inference if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length-1]) chosen_scores.append(chosen_rewards[i, chosen_length - 1])
rejected_scores.append(rejected_rewards[i, rejected_length-1]) rejected_scores.append(rejected_rewards[i, rejected_length - 1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean() loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
loss = loss / batch_size loss = loss / batch_size
@@ -80,10 +79,7 @@ class PairwiseTrainer(Trainer):
return loss return loss
def save_predictions( def save_predictions(self, predict_results: "PredictionOutput") -> None:
self,
predict_results: "PredictionOutput"
) -> None:
r""" r"""
Saves model predictions to `output_dir`. Saves model predictions to `output_dir`.

View File

@@ -1,20 +1,24 @@
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from ...data import get_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback from ...extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.ploting import plot_loss from ...extras.misc import fix_valuehead_checkpoint
from llmtuner.model import load_model_and_tokenizer from ...extras.ploting import plot_loss
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding from ...model import load_model_and_tokenizer
from llmtuner.train.rm.metric import compute_accuracy from ...train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.trainer import PairwiseTrainer from ...train.rm.metric import compute_accuracy
from llmtuner.train.utils import create_modelcard_and_push from ...train.rm.trainer import PairwiseTrainer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_rm( def run_rm(
@@ -22,16 +26,17 @@ def run_rm(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True) model_args, finetuning_args, training_args.do_train, add_valuehead=True
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") )
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments # Update arguments
training_args_dict = training_args.to_dict() training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict) training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer # Initialize our Trainer
@@ -40,15 +45,17 @@ def run_rm(
args=training_args, args=training_args,
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks + [SavePeftModelCallback()], callbacks=callbacks + [FixValueHeadModelCallback()],
compute_metrics=compute_accuracy, compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args) **split_dataset(dataset, data_args, training_args),
) )
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model() trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()

View File

@@ -1 +1,4 @@
-from llmtuner.train.sft.workflow import run_sft
+from .workflow import run_sft
+
+
+__all__ = ["run_sft"]

View File

@@ -1,11 +1,11 @@
import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX import numpy as np
from llmtuner.extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available from ...extras.constants import IGNORE_INDEX
) from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
@@ -14,7 +14,7 @@ if is_jieba_available():
import jieba import jieba
if is_nltk_available(): if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available(): if is_rouge_available():
from rouge_chinese import Rouge from rouge_chinese import Rouge

View File

@@ -1,13 +1,15 @@
import os
import json import json
import torch import os
import numpy as np
import torch.nn as nn
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.trainer import PredictionOutput from transformers.trainer import PredictionOutput
@@ -33,16 +35,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
if self.args.predict_with_generate: if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len: if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
inputs["labels"] = inputs["labels"][:, :prompt_len] inputs["labels"] = inputs["labels"][:, :prompt_len]
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
) )
if generated_tokens is not None and self.args.predict_with_generate: if generated_tokens is not None and self.args.predict_with_generate:
@@ -51,23 +53,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return loss, generated_tokens, labels return loss, generated_tokens, labels
def _pad_tensors_to_target_len( def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
self,
src_tensor: torch.Tensor,
tgt_tensor: torch.Tensor
) -> torch.Tensor:
r""" r"""
Pads the tensor to the same length as the target tensor. Pads the tensor to the same length as the target tensor.
""" """
assert self.tokenizer.pad_token_id is not None, "Pad token is required." assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory return padded_tensor.contiguous() # in contiguous memory
def save_predictions( def save_predictions(self, predict_results: "PredictionOutput") -> None:
self,
predict_results: "PredictionOutput"
) -> None:
r""" r"""
Saves model predictions to `output_dir`. Saves model predictions to `output_dir`.
@@ -79,15 +74,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info(f"Saving prediction results to {output_prediction_file}")
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) labels = np.where(
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
)
for i in range(len(preds)): for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
if len(pad_len): if len(pad_len):
preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1) # move pad token to last preds[i] = np.concatenate(
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
) # move pad token to last
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False) decoded_labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer: with open(output_prediction_file, "w", encoding="utf-8") as writer:
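The `np.nonzero` and `np.concatenate` pair in the hunk above rotates each left-padded prediction so that the pad tokens end up at the tail before decoding (note that the diff's variable `pad_len` actually holds the indices of non-pad tokens). A small illustrative example, with made-up values and `0` standing in for `pad_token_id`:

    import numpy as np

    pred = np.array([0, 0, 5, 6, 7])      # left-padded generation, pad_token_id == 0
    non_pad = np.nonzero(pred != 0)[0]    # indices of real tokens -> [2, 3, 4]
    if len(non_pad):
        pred = np.concatenate((pred[non_pad[0]:], pred[:non_pad[0]]), axis=-1)
    print(pred)                           # [5 6 7 0 0], pads moved to the end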

View File

@@ -1,20 +1,23 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from ...data import get_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor from ...extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss from ...extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics from ...train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer from ...train.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.train.utils import create_modelcard_and_push from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_sft( def run_sft(
@@ -23,27 +26,31 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None,
): ):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
if training_args.predict_with_generate: if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation tokenizer.padding_side = "left" # use left-padding in generation
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = DataCollatorForSeq2Seq( data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer, tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
) )
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args_dict = training_args.to_dict() training_args_dict = training_args.to_dict()
training_args_dict.update(dict( training_args_dict.update(
generation_max_length=training_args.generation_max_length or data_args.cutoff_len, dict(
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
)) generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
)
)
training_args = Seq2SeqTrainingArguments(**training_args_dict) training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer # Initialize our Trainer
@@ -54,7 +61,7 @@ def run_sft(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args, training_args) **split_dataset(dataset, data_args, training_args),
) )
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
@@ -76,7 +83,7 @@ def run_sft(
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None) metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
@@ -84,7 +91,7 @@ def run_sft(
# Predict # Predict
if training_args.do_predict: if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None) predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics) trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics)

View File

@@ -1,13 +1,18 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback import torch
from llmtuner.extras.logging import get_logger from transformers import PreTrainedModel
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.train.pt import run_pt from ..extras.callbacks import LogCallback
from llmtuner.train.sft import run_sft from ..extras.logging import get_logger
from llmtuner.train.rm import run_rm from ..hparams import get_infer_args, get_train_args
from llmtuner.train.ppo import run_ppo from ..model import load_model_and_tokenizer
from llmtuner.train.dpo import run_dpo from .dpo import run_dpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
@@ -36,19 +41,49 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
 def export_model(args: Optional[Dict[str, Any]] = None):
     model_args, _, finetuning_args, _ = get_infer_args(args)

+    if model_args.export_dir is None:
+        raise ValueError("Please specify `export_dir`.")
+
+    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
+        raise ValueError("Please merge adapters before quantizing the model.")
+
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

-    if getattr(model, "quantization_method", None) in ["gptq", "awq"]:
-        raise ValueError("Cannot export a GPTQ or AWQ quantized model.")
+    if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
+        raise ValueError("Cannot merge adapters to a quantized model.")

-    model.config.use_cache = True
-    model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
+
+    if getattr(model, "quantization_method", None):
+        model = model.to("cpu")
+    elif hasattr(model.config, "torch_dtype"):
+        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
+    else:
+        model = model.to(torch.float16).to("cpu")
+        setattr(model.config, "torch_dtype", torch.float16)
+
+    model.save_pretrained(
+        save_directory=model_args.export_dir,
+        max_shard_size="{}GB".format(model_args.export_size),
+        safe_serialization=(not model_args.export_legacy_format),
+    )
+
+    if model_args.export_hub_model_id is not None:
+        model.push_to_hub(
+            model_args.export_hub_model_id,
+            token=model_args.hf_hub_token,
+            max_shard_size="{}GB".format(model_args.export_size),
+            safe_serialization=(not model_args.export_legacy_format),
+        )

     try:
         tokenizer.padding_side = "left"  # restore padding side
         tokenizer.init_kwargs["padding_side"] = "left"
-        tokenizer.save_pretrained(finetuning_args.export_dir)
-    except:
+        tokenizer.save_pretrained(model_args.export_dir)
+        if model_args.export_hub_model_id is not None:
+            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
+    except Exception:
         logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@@ -1,15 +1,18 @@
import torch
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from llmtuner.extras.logging import get_logger import torch
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params from ..extras.logging import get_logger
from ..hparams import FinetuningArguments, ModelArguments
from ..model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -20,22 +23,24 @@ def create_modelcard_and_push(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    finetuning_args: "FinetuningArguments"
+    finetuning_args: "FinetuningArguments",
 ) -> None:
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
-            return
-        try:
-            trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
-        except Exception as err:
-            logger.warning("Failed to create model card: {}".format(str(err)))
+    kwargs = {
+        "tasks": "text-generation",
+        "finetuned_from": model_args.model_name_or_path,
+        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
+        "tags": ["llama-factory", finetuning_args.finetuning_type],
+    }
+    if not training_args.do_train:
+        pass
+    elif training_args.push_to_hub:
+        trainer.push_to_hub(**kwargs)
+    else:
+        trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub


 def create_ref_model(
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
-    add_valuehead: Optional[bool] = False
+    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
 ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@@ -44,11 +49,13 @@ def create_ref_model(
""" """
if finetuning_args.ref_model is not None: if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict() ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict( ref_model_args_dict.update(
model_name_or_path=finetuning_args.ref_model, dict(
checkpoint_dir=finetuning_args.ref_model_checkpoint, model_name_or_path=finetuning_args.ref_model,
quantization_bit=finetuning_args.ref_model_quantization_bit adapter_name_or_path=finetuning_args.ref_model_adapters,
)) quantization_bit=finetuning_args.ref_model_quantization_bit,
)
)
ref_model_args = ModelArguments(**ref_model_args_dict) ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora") ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer( ref_model, _ = load_model_and_tokenizer(
@@ -68,9 +75,7 @@ def create_ref_model(
def create_reward_model( def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead": ) -> "AutoModelForCausalLMWithValueHead":
r""" r"""
Creates reward model for PPO training. Creates reward model for PPO training.
@@ -81,24 +86,30 @@ def create_reward_model(
return finetuning_args.reward_model return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora": elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name: if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32 param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args) vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded." assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) model.register_buffer(
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
)
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None return None
else: else:
reward_model_args_dict = model_args.to_dict() reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict( reward_model_args_dict.update(
model_name_or_path=finetuning_args.reward_model, dict(
checkpoint_dir=finetuning_args.reward_model_checkpoint, model_name_or_path=finetuning_args.reward_model,
quantization_bit=finetuning_args.reward_model_quantization_bit adapter_name_or_path=finetuning_args.reward_model_adapters,
)) quantization_bit=finetuning_args.reward_model_quantization_bit,
)
)
reward_model_args = ModelArguments(**reward_model_args_dict) reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora") reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer( reward_model, _ = load_model_and_tokenizer(
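Reassembled, the new reward-model argument override on the right-hand side of the hunk above reads as follows (same content as the hunk, only the formatting is restored; the call to `load_model_and_tokenizer` continues past the hunk):

        reward_model_args_dict = model_args.to_dict()
        reward_model_args_dict.update(
            dict(
                model_name_or_path=finetuning_args.reward_model,
                adapter_name_or_path=finetuning_args.reward_model_adapters,
                quantization_bit=finetuning_args.reward_model_quantization_bit,
            )
        )
        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")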

View File

@@ -1 +1,4 @@
-from llmtuner.webui.interface import create_ui, create_web_demo
+from .interface import create_ui, create_web_demo
+
+
+__all__ = ["create_ui", "create_web_demo"]

View File

@@ -1,24 +1,24 @@
import gradio as gr import json
from gradio.components import Component # cannot use TYPE_CHECKING here from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING: if TYPE_CHECKING:
from llmtuner.webui.manager import Manager from .manager import Manager
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
def __init__( def __init__(
self, self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
) -> None: ) -> None:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode self.demo_mode = demo_mode
@@ -26,11 +26,12 @@ class WebChatModel(ChatModel):
self.tokenizer = None self.tokenizer = None
self.generating_args = GeneratingArguments() self.generating_args = GeneratingArguments()
if not lazy_init: # read arguments from command line if not lazy_init: # read arguments from command line
super().__init__() super().__init__()
if demo_mode: # load demo_config.json if exists if demo_mode: # load demo_config.json if exists
import json import json
try: try:
with open("demo_config.json", "r", encoding="utf-8") as f: with open("demo_config.json", "r", encoding="utf-8") as f:
args = json.load(f) args = json.load(f)
@@ -38,7 +39,7 @@ class WebChatModel(ChatModel):
super().__init__(args) super().__init__(args)
except AssertionError: except AssertionError:
print("Please provided model name and template in `demo_config.json`.") print("Please provided model name and template in `demo_config.json`.")
except: except Exception:
print("Cannot find `demo_config.json` at current directory.") print("Cannot find `demo_config.json` at current directory.")
@property @property
@@ -63,24 +64,26 @@ class WebChatModel(ChatModel):
yield error yield error
return return
if get("top.checkpoints"): if get("top.adapter_path"):
checkpoint_dir = ",".join([ adapter_name_or_path = ",".join(
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") [
]) get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else: else:
checkpoint_dir = None adapter_name_or_path = None
yield ALERTS["info_loading"][lang] yield ALERTS["info_loading"][lang]
args = dict( args = dict(
model_name_or_path=get("top.model_path"), model_name_or_path=get("top.model_path"),
checkpoint_dir=checkpoint_dir, adapter_name_or_path=adapter_name_or_path,
finetuning_type=get("top.finetuning_type"), finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"), template=get("top.template"),
system_prompt=get("top.system_prompt"), flash_attn=(get("top.booster") == "flash_attn"),
flash_attn=get("top.flash_attn"), use_unsloth=(get("top.booster") == "unsloth"),
shift_attn=get("top.shift_attn"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
) )
super().__init__(args) super().__init__(args)
@@ -103,22 +106,39 @@ class WebChatModel(ChatModel):
     def predict(
         self,
         chatbot: List[Tuple[str, str]],
+        role: str,
         query: str,
-        history: List[Tuple[str, str]],
+        messages: Sequence[Tuple[str, str]],
         system: str,
+        tools: str,
         max_new_tokens: int,
         top_p: float,
-        temperature: float
-    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+        temperature: float,
+    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
+        query_messages = messages + [{"role": role, "content": query}]
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
-            new_history = history + [(query, response)]
-            chatbot[-1] = [query, self.postprocess(response)]
-            yield chatbot, new_history
+            if tools:
+                result = self.template.format_tools.extract(response)
+            else:
+                result = response
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                arguments = json.loads(arguments)
+                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
+                output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
+                bot_text = "```json\n" + tool_call + "\n```"
+            else:
+                output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                bot_text = result
+
+            chatbot[-1] = [query, self.postprocess(bot_text)]
+            yield chatbot, output_messages

     def postprocess(self, response: str) -> str:
         blocks = response.split("```")
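To make the new tool-handling branch concrete, here is a hypothetical single turn, assuming the active template's tool extractor returns a `(name, arguments)` tuple as in the hunk above; the function name and arguments are invented for illustration, and `Role.FUNCTION.value` is shown as a plain string:

    import json

    result = ("get_weather", '{"location": "Berlin"}')   # hypothetical extractor output

    name, arguments = result
    tool_call = json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
    # appended to the conversation as a function-role message and rendered as a fenced json block in the chatbot
    output_message = {"role": "function", "content": tool_call}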

View File

@@ -1,39 +1,28 @@
import os
import json import json
import gradio as gr import os
from collections import defaultdict
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
)
from llmtuner.extras.constants import ( import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE, DEFAULT_MODULE,
DEFAULT_TEMPLATE, DEFAULT_TEMPLATE,
PEFT_METHODS,
SUPPORTED_MODELS, SUPPORTED_MODELS,
TRAINING_STAGES, TRAINING_STAGES,
DownloadSource DownloadSource,
) )
from llmtuner.extras.misc import use_modelscope from ..extras.misc import use_modelscope
from llmtuner.hparams.data_args import DATA_CONFIG
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data" DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves" DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config" USER_CONFIG = "user.config"
CKPT_NAMES = [
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
def get_save_dir(*args) -> os.PathLike: def get_save_dir(*args) -> os.PathLike:
@@ -48,7 +37,7 @@ def load_config() -> Dict[str, Any]:
try: try:
with open(get_config_path(), "r", encoding="utf-8") as f: with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
except: except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@@ -65,13 +54,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def get_model_path(model_name: str) -> str: def get_model_path(model_name: str) -> str:
user_config = load_config() user_config = load_config()
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, []) path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "") model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
if ( if (
use_modelscope() use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE) and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT) and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path ): # replace path
model_path = path_dict.get(DownloadSource.MODELSCOPE) model_path = path_dict.get(DownloadSource.MODELSCOPE)
return model_path return model_path
@@ -90,18 +79,20 @@ def get_template(model_name: str) -> str:
     return "default"


-def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
-    checkpoints = []
-    if model_name:
+def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type not in PEFT_METHODS:
+        return gr.update(value=[], choices=[], interactive=False)
+
+    adapters = []
+    if model_name and finetuning_type == "lora":
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
-            for checkpoint in os.listdir(save_dir):
-                if (
-                    os.path.isdir(os.path.join(save_dir, checkpoint))
-                    and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
+            for adapter in os.listdir(save_dir):
+                if os.path.isdir(os.path.join(save_dir, adapter)) and any(
+                    os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
                 ):
-                    checkpoints.append(checkpoint)
-    return gr.update(value=[], choices=checkpoints)
+                    adapters.append(adapter)
+
+    return gr.update(value=[], choices=adapters, interactive=True)


 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
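`list_adapters` now only offers folders that actually contain PEFT adapter weights. A hedged usage sketch (the model and folder names are placeholders; `ADAPTER_NAMES` comes from the `peft.utils` constants imported above, which typically resolve to adapter_model.bin and adapter_model.safetensors):

    # saves/<model_name>/lora/<adapter_dir>/ must contain one of ADAPTER_NAMES to be listed
    update = list_adapters("LLaMA2-7B-Chat", "lora")   # placeholder model name
    # -> gr.update(value=[], choices=["sft-lora-1", ...], interactive=True) when matching folders exist
    # -> gr.update(value=[], choices=[], interactive=False) for non-PEFT finetuning types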

View File

@@ -1,6 +1,16 @@
-from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.train import create_train_tab
-from llmtuner.webui.components.eval import create_eval_tab
-from llmtuner.webui.components.infer import create_infer_tab
-from llmtuner.webui.components.export import create_export_tab
-from llmtuner.webui.components.chatbot import create_chat_box
+from .chatbot import create_chat_box
+from .eval import create_eval_tab
+from .export import create_export_tab
+from .infer import create_infer_tab
+from .top import create_top
+from .train import create_train_tab
+
+
+__all__ = [
+    "create_chat_box",
+    "create_eval_tab",
+    "create_export_tab",
+    "create_infer_tab",
+    "create_top",
+    "create_train_tab",
+]

View File

@@ -1,49 +1,63 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from ...data import Role
from ..utils import check_json_schema
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.blocks import Block from gradio.blocks import Block
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_chat_box( def create_chat_box(
engine: "Engine", engine: "Engine", visible: Optional[bool] = False
visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box: with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot() chatbot = gr.Chatbot()
history = gr.State([]) messages = gr.State([])
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False) system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=2)
query = gr.Textbox(show_label=False, lines=8) query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary") submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1): with gr.Column(scale=1):
clear_btn = gr.Button()
gen_kwargs = engine.chatter.generating_args gen_kwargs = engine.chatter.generating_args
max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1) max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
clear_btn = gr.Button()
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
submit_btn.click( submit_btn.click(
engine.chatter.predict, engine.chatter.predict,
[chatbot, query, history, system, max_new_tokens, top_p, temperature], [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history], [chatbot, messages],
show_progress=True show_progress=True,
).then( ).then(lambda: gr.update(value=""), outputs=[query])
lambda: gr.update(value=""), outputs=[query]
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
return chat_box, chatbot, history, dict( return (
system=system, chat_box,
query=query, chatbot,
submit_btn=submit_btn, messages,
clear_btn=clear_btn, dict(
max_new_tokens=max_new_tokens, role=role,
top_p=top_p, system=system,
temperature=temperature tools=tools,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
clear_btn=clear_btn,
),
) )

View File

@@ -1,9 +1,11 @@
import os
import json import json
import gradio as gr import os
from typing import TYPE_CHECKING, Any, Dict, Tuple from typing import TYPE_CHECKING, Any, Dict, Tuple
from llmtuner.webui.common import DATA_CONFIG import gradio as gr
from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
@@ -21,8 +23,11 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: try:
dataset_info = json.load(f) with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.update(interactive=False)
if ( if (
len(dataset) > 0 len(dataset) > 0
@@ -45,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
elif data_file.endswith(".jsonl"): elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f] data = [json.loads(line) for line in f]
else: else:
data = [line for line in f] data = [line for line in f] # noqa: C416
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True) return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
@@ -64,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
with gr.Row(): with gr.Row():
preview_samples = gr.JSON(interactive=False) preview_samples = gr.JSON(interactive=False)
dataset.change( dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
).then(
lambda: 0, outputs=[page_index], queue=False lambda: 0, outputs=[page_index], queue=False
) )
data_preview_btn.click( data_preview_btn.click(
get_preview, get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
) )
prev_btn.click( prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
prev_page, [page_index], [page_index], queue=False get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
) )
next_btn.click( next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
next_page, [page_index, preview_count], [page_index], queue=False get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
) )
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return dict( return dict(
@@ -99,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
prev_btn=prev_btn, prev_btn=prev_btn,
next_btn=next_btn, next_btn=next_btn,
close_btn=close_btn, close_btn=close_btn,
preview_samples=preview_samples preview_samples=preview_samples,
) )

View File

@@ -1,12 +1,15 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR import gradio as gr
from llmtuner.webui.components.data import create_preview_box
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
@@ -30,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
predict = gr.Checkbox(value=True) predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict}) input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict( elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
))
with gr.Row(): with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1) max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
@@ -41,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox() output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir}) input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict( elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
))
with gr.Row(): with gr.Row():
cmd_preview_btn = gr.Button() cmd_preview_btn = gr.Button()
@@ -58,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_box = gr.Markdown() output_box = gr.Markdown()
output_elems = [output_box, process_bar] output_elems = [output_box, process_bar]
elem_dict.update(dict( elem_dict.update(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, dict(
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box cmd_preview_btn=cmd_preview_btn,
)) start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
)
)
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems) cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
start_btn.click(engine.runner.run_eval, input_elems, output_elems) start_btn.click(engine.runner.run_eval, input_elems, output_elems)

View File

@@ -1,47 +1,68 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict, Generator, List
-from llmtuner.train import export_model
-from llmtuner.webui.common import get_save_dir
-from llmtuner.webui.locales import ALERTS
+
+import gradio as gr
+
+from ...train import export_model
+from ..common import get_save_dir
+from ..locales import ALERTS

 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine
+
+
+GPTQ_BITS = ["8", "4", "3", "2"]


 def save_model(
     lang: str,
     model_name: str,
     model_path: str,
-    checkpoints: List[str],
+    adapter_path: List[str],
     finetuning_type: str,
     template: str,
     max_shard_size: int,
-    export_dir: str
+    export_quantization_bit: int,
+    export_quantization_dataset: str,
+    export_legacy_format: bool,
+    export_dir: str,
 ) -> Generator[str, None, None]:
     error = ""
     if not model_name:
         error = ALERTS["err_no_model"][lang]
     elif not model_path:
         error = ALERTS["err_no_path"][lang]
-    elif not checkpoints:
-        error = ALERTS["err_no_checkpoint"][lang]
     elif not export_dir:
         error = ALERTS["err_no_export_dir"][lang]
+    elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
+        error = ALERTS["err_no_dataset"][lang]
+    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+        error = ALERTS["err_no_adapter"][lang]

     if error:
         gr.Warning(error)
         yield error
         return

+    if adapter_path:
+        adapter_name_or_path = ",".join(
+            [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
+        )
+    else:
+        adapter_name_or_path = None
+
     args = dict(
         model_name_or_path=model_path,
-        checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
+        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
-        export_size=max_shard_size
+        export_size=max_shard_size,
+        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
+        export_quantization_dataset=export_quantization_dataset,
+        export_legacy_format=export_legacy_format,
    )

     yield ALERTS["info_exporting"][lang]
@@ -51,9 +72,12 @@ def save_model(
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
export_dir = gr.Textbox()
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100) max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_legacy_format = gr.Checkbox()
export_dir = gr.Textbox()
export_btn = gr.Button() export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False) info_box = gr.Textbox(show_label=False, interactive=False)
@@ -63,18 +87,24 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_name("top.lang"), engine.manager.get_elem_by_name("top.lang"),
engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.model_path"), engine.manager.get_elem_by_name("top.model_path"),
engine.manager.get_elem_by_name("top.checkpoints"), engine.manager.get_elem_by_name("top.adapter_path"),
engine.manager.get_elem_by_name("top.finetuning_type"), engine.manager.get_elem_by_name("top.finetuning_type"),
engine.manager.get_elem_by_name("top.template"), engine.manager.get_elem_by_name("top.template"),
max_shard_size, max_shard_size,
export_dir export_quantization_bit,
export_quantization_dataset,
export_legacy_format,
export_dir,
], ],
[info_box] [info_box],
) )
return dict( return dict(
export_dir=export_dir,
max_shard_size=max_shard_size, max_shard_size=max_shard_size,
export_quantization_bit=export_quantization_bit,
export_quantization_dataset=export_quantization_dataset,
export_legacy_format=export_legacy_format,
export_dir=export_dir,
export_btn=export_btn, export_btn=export_btn,
info_box=info_box info_box=info_box,
) )

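Editor's note on the export diff above: the web UI now merges any selected adapters into adapter_name_or_path and only enables GPTQ-style post-quantization when a bit width from GPTQ_BITS plus a calibration dataset are supplied. The sketch below mirrors that argument assembly outside Gradio; the helper name, its defaults, and the idea that a downstream export entry point consumes this dict are illustrative assumptions, not verified llmtuner API.

# Sketch only: mirrors how the updated save_model() assembles its export arguments.
# The helper name and defaults are assumptions for illustration.
from typing import Any, Dict, List

GPTQ_BITS = ["8", "4", "3", "2"]  # same constant the web UI defines


def build_export_args(
    model_path: str,
    adapter_dirs: List[str],
    template: str,
    export_dir: str,
    finetuning_type: str = "lora",
    max_shard_size: int = 1,
    export_quantization_bit: str = "none",
    export_quantization_dataset: str = "data/c4_demo.json",
    export_legacy_format: bool = False,
) -> Dict[str, Any]:
    quantize = export_quantization_bit in GPTQ_BITS
    if quantize and not export_quantization_dataset:
        raise ValueError("GPTQ export needs a calibration dataset")
    if not quantize and not adapter_dirs:
        raise ValueError("a plain (non-quantized) export expects at least one adapter to merge")

    return dict(
        model_name_or_path=model_path,
        adapter_name_or_path=",".join(adapter_dirs) if adapter_dirs else None,
        finetuning_type=finetuning_type,
        template=template,
        export_dir=export_dir,
        export_size=max_shard_size,
        export_quantization_bit=int(export_quantization_bit) if quantize else None,
        export_quantization_dataset=export_quantization_dataset,
        export_legacy_format=export_legacy_format,
    )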

@@ -1,11 +1,14 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.webui.components.chatbot import create_chat_box
+import gradio as gr
+
+from .chatbot import create_chat_box

 if TYPE_CHECKING:
     from gradio.components import Component

-    from llmtuner.webui.engine import Engine
+    from ..engine import Engine


 def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:

@@ -22,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))

-    load_btn.click(
-        engine.chatter.load_model, input_elems, [info_box]
-    ).then(
+    load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
         lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )

-    unload_btn.click(
-        engine.chatter.unload_model, input_elems, [info_box]
-    ).then(
+    unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
         lambda: ([], []), outputs=[chatbot, history]
-    ).then(
-        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
-    )
+    ).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])

     return elem_dict

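Editor's note on the inference tab diff above: the change mostly collapses multi-line .click()/.then() chains onto single lines. For readers unfamiliar with the pattern, a Gradio event listener returns a dependency object whose .then() schedules a follow-up step after the first one finishes. A small self-contained demo (made-up components, not repository code) assuming a Gradio 3.x install:

# Standalone illustration of the .click().then() chaining used by the infer tab.
import gradio as gr

with gr.Blocks() as demo:
    status = gr.Textbox(label="status")
    with gr.Column(visible=False) as panel:  # hidden until the first step completes
        gr.Markdown("model is ready")
    load_btn = gr.Button("Load")

    # First step runs the (fake) loader, second step reveals the panel afterwards.
    load_btn.click(lambda: "loaded", outputs=[status]).then(
        lambda: gr.update(visible=True), outputs=[panel]
    )

if __name__ == "__main__":
    demo.launch()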

@@ -1,10 +1,12 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.data.template import templates
-from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
-from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
-from llmtuner.webui.utils import can_quantize
+import gradio as gr
+
+from ...data import templates
+from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ..common import get_model_path, get_template, list_adapters, save_config
+from ..utils import can_quantize

 if TYPE_CHECKING:
     from gradio.components import Component

@@ -14,61 +16,44 @@ def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], scale=1)
+        lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)

     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        checkpoints = gr.Dropdown(multiselect=True, scale=5)
+        adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
         refresh_btn = gr.Button(scale=1)

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
-            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
-            system_prompt = gr.Textbox(scale=2)
-
-    with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
-        with gr.Row():
-            with gr.Column():
-                flash_attn = gr.Checkbox(value=False)
-                shift_attn = gr.Checkbox(value=False)
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
+            template = gr.Dropdown(choices=list(templates.keys()), value="default")
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
+            booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")

-    model_name.change(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
-    ).then(
+    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
-    ).then(
-        get_template, [model_name], [template], queue=False
-    )  # do not save config since the below line will save
+    ).then(get_template, [model_name], [template], queue=False)  # do not save config since the below line will save

     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

-    finetuning_type.change(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
-    ).then(
+    finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         can_quantize, [finetuning_type], [quantization_bit], queue=False
     )

-    refresh_btn.click(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
-    )
+    refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)

     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
         finetuning_type=finetuning_type,
-        checkpoints=checkpoints,
+        adapter_path=adapter_path,
         refresh_btn=refresh_btn,
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        system_prompt=system_prompt,
-        llama_tab=llama_tab,
-        flash_attn=flash_attn,
-        shift_attn=shift_attn,
-        rope_scaling=rope_scaling
+        rope_scaling=rope_scaling,
+        booster=booster,
     )

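Editor's note on the top-bar diff above: the checkpoints dropdown is replaced by adapter_path, populated by a list_adapters helper instead of list_checkpoint. The repository's actual helper is not shown in this diff (and it returns a Gradio dropdown update); the following is only a hypothetical sketch of what such a helper could do, assuming adapters live under a per-model save directory and are identified by an adapter_config.json marker.

# Hypothetical sketch of an adapter-listing helper; the directory layout and the
# "adapter_config.json" marker are assumptions, not the repository's verified code.
import os
from typing import List

DEFAULT_SAVE_DIR = "saves"  # assumed root for fine-tuned outputs


def list_adapters_sketch(model_name: str, finetuning_type: str) -> List[str]:
    """Return subfolders that look like PEFT adapters for the given model/method."""
    save_dir = os.path.join(DEFAULT_SAVE_DIR, model_name, finetuning_type)
    if finetuning_type != "lora" or not os.path.isdir(save_dir):
        return []  # only LoRA-style runs produce loadable adapters here

    return [
        folder
        for folder in sorted(os.listdir(save_dir))
        if os.path.isfile(os.path.join(save_dir, folder, "adapter_config.json"))
    ]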

@@ -1,15 +1,18 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

+import gradio as gr
 from transformers.trainer_utils import SchedulerType

-from llmtuner.extras.constants import TRAINING_STAGES
-from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
-from llmtuner.webui.utils import gen_plot
+from ...extras.constants import TRAINING_STAGES
+from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
+from ..components.data import create_preview_box
+from ..utils import gen_plot

 if TYPE_CHECKING:
     from gradio.components import Component

-    from llmtuner.webui.engine import Engine
+    from ..engine import Engine


 def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

@@ -28,84 +31,143 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
     input_elems.update({training_stage, dataset_dir, dataset})
-    elem_dict.update(dict(
-        training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
-    ))
+    elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         learning_rate = gr.Textbox(value="5e-5")
         num_train_epochs = gr.Textbox(value="3.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+        compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")

     input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
-    elem_dict.update(dict(
-        cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
-        max_samples=max_samples, compute_type=compute_type
-    ))
+    elem_dict.update(
+        dict(
+            cutoff_len=cutoff_len,
+            learning_rate=learning_rate,
+            num_train_epochs=num_train_epochs,
+            max_samples=max_samples,
+            compute_type=compute_type,
+        )
+    )

     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
-        lr_scheduler_type = gr.Dropdown(
-            choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
-        )
+        batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
+        gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
+        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

     input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
-    elem_dict.update(dict(
-        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
-        lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
-    ))
+    elem_dict.update(
+        dict(
+            batch_size=batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            lr_scheduler_type=lr_scheduler_type,
+            max_grad_norm=max_grad_norm,
+            val_size=val_size,
+        )
+    )

-    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+    with gr.Accordion(label="Extra config", open=False) as extra_tab:
         with gr.Row():
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
-            neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
-            with gr.Column():
-                train_on_prompt = gr.Checkbox(value=False)
-                upcast_layernorm = gr.Checkbox(value=False)
+            neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
+
+        with gr.Row():
+            resize_vocab = gr.Checkbox()
+            sft_packing = gr.Checkbox()
+            upcast_layernorm = gr.Checkbox()
+            use_llama_pro = gr.Checkbox()

-    input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
-    elem_dict.update(dict(
-        advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
-        neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
-    ))
+    input_elems.update(
+        {
+            logging_steps,
+            save_steps,
+            warmup_steps,
+            neftune_alpha,
+            resize_vocab,
+            sft_packing,
+            upcast_layernorm,
+            use_llama_pro,
+        }
+    )
+    elem_dict.update(
+        dict(
+            extra_tab=extra_tab,
+            logging_steps=logging_steps,
+            save_steps=save_steps,
+            warmup_steps=warmup_steps,
+            neftune_alpha=neftune_alpha,
+            resize_vocab=resize_vocab,
+            sft_packing=sft_packing,
+            upcast_layernorm=upcast_layernorm,
+            use_llama_pro=use_llama_pro,
+        )
+    )
+
+    with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
+        with gr.Row():
+            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
+            name_module_trainable = gr.Textbox(scale=3)
+
+    input_elems.update({num_layer_trainable, name_module_trainable})
+    elem_dict.update(
+        dict(
+            freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
+        )
+    )

     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1)
             lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            lora_target = gr.Textbox(scale=1)
-            additional_target = gr.Textbox(scale=1)
-            resume_lora_training = gr.Checkbox(value=True, scale=1)
+            lora_target = gr.Textbox(scale=2)
+
+        with gr.Row():
+            use_rslora = gr.Checkbox(scale=1)
+            use_dora = gr.Checkbox(scale=1)
+            create_new_adapter = gr.Checkbox(scale=1)
+            additional_target = gr.Textbox(scale=2)

-    input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
-    elem_dict.update(dict(
-        lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
-        additional_target=additional_target, resume_lora_training=resume_lora_training,
-    ))
+    input_elems.update(
+        {lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
+    )
+    elem_dict.update(
+        dict(
+            lora_tab=lora_tab,
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            lora_target=lora_target,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+            create_new_adapter=create_new_adapter,
+            additional_target=additional_target,
+        )
+    )

     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            reward_model = gr.Dropdown(scale=3)
+            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
+            reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
             refresh_btn = gr.Button(scale=1)

         refresh_btn.click(
-            list_checkpoint,
+            list_adapters,
             [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
             [reward_model],
-            queue=False
+            queue=False,
         )

-    input_elems.update({dpo_beta, reward_model})
-    elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
+    input_elems.update({dpo_beta, dpo_ftx, reward_model})
+    elem_dict.update(
+        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
+    )

     with gr.Row():
         cmd_preview_btn = gr.Button()

@@ -118,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                 output_dir = gr.Textbox()

             with gr.Row():
-                resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+                resume_btn = gr.Checkbox(visible=False, interactive=False)
                 process_bar = gr.Slider(visible=False, interactive=False)

             with gr.Box():

@@ -135,20 +197,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort, queue=False)
     resume_btn.change(engine.runner.monitor, outputs=output_elems)

-    elem_dict.update(dict(
-        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
-        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
-    ))
+    elem_dict.update(
+        dict(
+            cmd_preview_btn=cmd_preview_btn,
+            start_btn=start_btn,
+            stop_btn=stop_btn,
+            output_dir=output_dir,
+            resume_btn=resume_btn,
+            process_bar=process_bar,
+            output_box=output_box,
+            loss_viewer=loss_viewer,
+        )
+    )

     output_box.change(
         gen_plot,
         [
             engine.manager.get_elem_by_name("top.model_name"),
             engine.manager.get_elem_by_name("top.finetuning_type"),
-            output_dir
+            output_dir,
         ],
         loss_viewer,
-        queue=False
+        queue=False,
     )

     return elem_dict
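Editor's note on the training tab diff above: the new LoRA checkboxes (use_rslora, use_dora, create_new_adapter) and the lora_alpha slider surface options that typically map onto a PEFT LoraConfig downstream. The exact wiring inside LLaMA-Factory is not shown in this diff; the following is a minimal sketch assuming a recent peft release that supports rank-stabilized LoRA and DoRA, with a placeholder model name.

# Sketch: how the new web UI toggles roughly correspond to PEFT LoraConfig fields.
# Assumes peft >= 0.9.0 (use_dora) and >= 0.7.0 (use_rslora); model name is a placeholder.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                  # lora_rank slider
    lora_alpha=16,                        # new lora_alpha slider
    lora_dropout=0.1,                     # lora_dropout slider
    target_modules=["q_proj", "v_proj"],  # lora_target textbox (example values)
    use_rslora=True,                      # use_rslora checkbox
    use_dora=False,                       # use_dora checkbox
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()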

Some files were not shown because too many files have changed in this diff.