646 Commits

Author SHA1 Message Date
hiyouga
debfd46749 release v0.5.2
Former-commit-id: 0189867816b0eab92fb2a1b5f1b1da079bd161a7
2024-02-20 11:12:43 +08:00
hiyouga
5ccf8fcd6b update webui
Former-commit-id: 9e0f7c362d40b78d57e77d52eaa96e678cebadcd
2024-02-19 16:49:58 +08:00
hiyouga
7bd1991513 add test scripts
Former-commit-id: fdaa4843961257b48cc32d83d30f2efe18b9fd5a
2024-02-19 02:09:13 +08:00
hiyouga
456e4ca569 fix safetensors
Former-commit-id: 06478ae5302d5fc6eb7afedc69335ce2f32808c6
2024-02-18 18:12:16 +08:00
hiyouga
6bf0fe4913 fix #2481
Former-commit-id: 2a4e3e4a26a2fad77ccc476be7d45434b8af4a55
2024-02-15 19:07:47 +08:00
hiyouga
596b6828cb support llama pro #2338 , add rslora
Former-commit-id: 40d659b7f30dd5a004703c176ec1f22dc864e505
2024-02-15 02:27:36 +08:00
hoshi-hiyouga
b403f8d8a8 Merge pull request #2474 from younesbelkada/add-hf-tags
FEAT: add HF tags for models that have been trained with llama-factory
Former-commit-id: f35d96817e61da9fa7789b93b0350c9f95afc40a
2024-02-14 10:26:03 +08:00
younesbelkada
590b6c2143 add v1 hf tags
Former-commit-id: a29cc9f4472c95cd6a43ea350ab728e0a8069c6e
2024-02-13 05:58:49 +00:00
hiyouga
5537ef1e7d fix #2471
Former-commit-id: a408be8be1cf99cd4468a9905c27ec454f312b9a
2024-02-12 21:07:46 +08:00
hiyouga
5f83860aa1 add option to disable version check
Former-commit-id: fd769cb2de696aee3c5e882237e16eace6a9d675
2024-02-10 22:31:23 +08:00
hiyouga
62b6a7971a update data/readme
Former-commit-id: aa566e3cea5bc75688b4399a9da07be0b35b921c
2024-02-10 21:04:29 +08:00
hiyouga
1d16e87c5f update default template
Former-commit-id: f32b55649a9f95109a6d180216eb67f959d060da
2024-02-10 16:44:47 +08:00
hiyouga
1955a8ea5a improve aligner
Former-commit-id: cc7296b92e10c24967fc753393275b71d300683f
2024-02-10 16:39:19 +08:00
hoshi-hiyouga
a41fa6e730 Merge pull request #2462 from mnmueller/main
Enable Parsing of SlimOrca

Former-commit-id: 99eed520b87152ca6b89c2a068b09200fd45f30d
2024-02-09 22:55:48 +08:00
hiyouga
b98a64448a improve fix tokenizer
Former-commit-id: 57b138abad6397596bc47be94e092e8fabedc06f
2024-02-09 14:53:14 +08:00
Mark Mueller
1ce82f391a Slim Orca data parsing
Former-commit-id: f2d8efede7e20edafed0d5446eb64f2d419949b1
2024-02-08 19:32:20 +01:00
Mark Mueller
4d473894fd Slim Orca data parsing
Former-commit-id: ca57d27c39d4e7bc3dd7c3207a23d23d2cbd446b
2024-02-08 17:56:18 +01:00
Mark Mueller
5788b7c7d0 Slim Orca data parsing
Former-commit-id: 3016427be4e63fd25f40bc5a0d1f8cedc0997334
2024-02-08 17:54:18 +01:00
Mark Mueller
04515f6b55 Slim Orca data parsing
Former-commit-id: 4dca3907964d27abc2b21eb55c75676901c98912
2024-02-08 17:52:36 +01:00
Mark Mueller
96f8ccf3d5 SlimOrca aligner
Former-commit-id: 928dda93867c2327a7957c04648592044ccf9daf
2024-02-08 08:28:32 -08:00
hoshi-hiyouga
2c3ef480a6 Merge pull request #2423 from mayflower/main
Support for german sft and dpo

Former-commit-id: 8e282e4e6bee6493b1bd38ba239ca49a6a840a92
2024-02-07 15:58:20 +08:00
hiyouga
fa6873122c Update tests.yml
Former-commit-id: c882b7cf339304ff16a36b1544a3b5f1194ef346
2024-02-07 01:18:22 +08:00
hiyouga
34bc0c22b1 lint
Former-commit-id: 6b1f89b6494e9b6b087fe90600617a3024e014e5
2024-02-07 01:10:04 +08:00
hiyouga
e5484b2729 Update pyproject.toml
Former-commit-id: 650251ea77fae2e2595ca804f49efdd230dbb5b1
2024-02-07 00:45:58 +08:00
hiyouga
f67f781fed update gc kwargs
Former-commit-id: 0cb81c156bc8c21a4bbdd3289a491f78dfcaf730
2024-02-07 00:38:24 +08:00
hiyouga
b564b97b7e fix #2438
Former-commit-id: 412d856eeada2abcea598fac0a8d35ae90cc9c01
2024-02-06 15:23:08 +08:00
hiyouga
0dd68d1e06 add models
Former-commit-id: 0fdf61b2f765c125acda4f406eb25b3e59e75db2
2024-02-06 14:57:23 +08:00
hiyouga
73f40f1ca4 support qwen1.5
Former-commit-id: 8a03a572b058c5cc4ff598670dc8595b2b97e374
2024-02-06 00:10:51 +08:00
hoshi-hiyouga
ea53bebac4 fix #2436
Update test_toolcall.py

Former-commit-id: 39c539b6470c532ac639efbd2a1c485d2f5d485f
2024-02-05 22:55:28 +08:00
hoshi-hiyouga
00418012bd Update test_toolcall.py
Former-commit-id: f50a684a9d6fc2351436d3d7020dc84bc1553a5d
2024-02-05 22:51:03 +08:00
hoshi-hiyouga
5f3d8c514b Update test_toolcall.py
Former-commit-id: 97bcae546ab80737a906e5e28953f41b657f6c99
2024-02-05 22:50:43 +08:00
tao.jun
cb39a3f1c4 Update test_toolcall.py
Add openai version notes

Former-commit-id: 9ea4ab214e64f73ec902e76b82fc42419571fd66
2024-02-05 20:49:23 +08:00
Johann-Peter Hartmann
4d78fe6ece Merge branch 'hiyouga:main' into main
Former-commit-id: efbb0153981d0650f3a581e324b83054ca8063c1
2024-02-04 13:55:00 +00:00
hiyouga
a3e3ea9846 fix #2421
Former-commit-id: 43918c12310f7560d3820e5c6d72964309afeb8b
2024-02-04 21:02:55 +08:00
Johann-Peter Hartmann
feba34e82d Merge branch 'hiyouga:main' into main
Former-commit-id: 0395d0aafb69e86645e6b0a36b8f8dadb82219e0
2024-02-04 12:51:25 +00:00
hiyouga
e134013e04 fix reserved label len
Former-commit-id: b06d6c05a1911f329252a7572240048e456affdc
2024-02-04 17:54:26 +08:00
hiyouga
5589d0296a fix #2420
Former-commit-id: 7a34087e4db62e603c9a9a26d8ff3910d7b10c40
2024-02-04 15:51:47 +08:00
hiyouga
de0ebab464 fix #2189
Former-commit-id: b3d81b229d376671e1c12229aeb487b0d84f2548
2024-02-04 00:47:37 +08:00
hiyouga
f2e7122a96 bump up transformers version
Former-commit-id: 82f4d4301ed9f31b160d6313a1d2d44a22865f4d
2024-02-04 00:01:16 +08:00
hiyouga
996cc5d900 fix #2397
Former-commit-id: 7404692808f2288d539668d364965ad104dacadb
2024-02-03 23:45:31 +08:00
hiyouga
a2ae5bd867 add hint for freeze #2412
Former-commit-id: 9600c93633629605573d908019563fa3870ad6f8
2024-02-03 23:38:56 +08:00
hiyouga
5fa52e87cb fix #2376
Former-commit-id: 8e2cfa7cca21b7fd4538d72114e36f704bcc82fe
2024-02-03 23:14:31 +08:00
hiyouga
bcd76d2c7a support minicpm #2404
Former-commit-id: 4449e91cbee8fd804cf8bf1ff6b9f5301fde94ed
2024-02-03 22:36:46 +08:00
Johann-Peter Hartmann
36fcbedc11 add simple german chatml template chatml_de
Former-commit-id: 9f1d67c09f1af2c7aa383adec66842cacde92e33
2024-02-03 09:01:15 +01:00
Johann-Peter Hartmann
1dad01cc53 Merge branch 'hiyouga:main' into main
Former-commit-id: c350237d891df7edd7e681f9da5ac1446fdeb568
2024-02-03 08:43:12 +01:00
hoshi-hiyouga
5fb21f6e54 Merge pull request #2411 from lxsyz/main
fix eos_token_id=0 bug

Former-commit-id: 019a353e74ec70a9a2d8987df1ed19483413211a
2024-02-02 17:38:16 +08:00
Fallen Angel
08dfac8352 fix eos_token_id=0 bug
when eos_token_id=0, will never add eos_token

Former-commit-id: 576b4881c386d897462a875b28066ce9d6e06dd5
2024-02-02 17:34:48 +08:00
Johann-Peter Hartmann
956751e419 Merge branch 'hiyouga:main' into main
Former-commit-id: 25b0a11c715f87812edba1ca14d3122a75f421de
2024-01-31 14:05:52 +01:00
hiyouga
fe2ae04c91 fix #2388
Former-commit-id: 203a36c9adfd9aa0f35fbf8089c9138534d68c53
2024-01-31 17:23:56 +08:00
hiyouga
5b8712d061 fix autoset attn impl, update data readme
Former-commit-id: 34a6e5f82baf45cc8dbb11f9f7ab4a480ab7ec5c
2024-01-31 11:58:07 +08:00
Johann-Peter Hartmann
dc7ff90c1e Add support for german datasets
Former-commit-id: bbc038aa236952597e97d1ccf1ae2d64a16339b5
2024-01-30 10:18:01 +01:00
hiyouga
1ace676170 fix #2320
Former-commit-id: e0b0c4415aaf80e75f6dd4f3777a0616b0e60f84
2024-01-24 16:19:18 +08:00
hoshi-hiyouga
8947a87b95 Merge pull request #2319 from ftgreat/main
Add patch_mixtral_replace_moe_impl for full training Mitral using DeepSpeed Zero3

Former-commit-id: 0fadcd5f9deb9f03d341b6611c15f337f07e32d1
2024-01-24 15:32:26 +08:00
ldwang
786a2f1103 Add patch_mixtral_replace_moe_impl for full training Mitral using DeepSpeed Zero3.
Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: 5f50c02f0e425737cd80abdf8fde9e25abf13083
2024-01-24 15:25:31 +08:00
ldwang
36ac14a566 Add patch_mixtral_replace_moe_impl for full training Mitral using DeepSpeed Zero3.
Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: d1413dcec8a3b1d671f240b82a689c72b54d7b93
2024-01-24 14:43:16 +08:00
hiyouga
7a048fc91d add hint
Former-commit-id: c540ef41bda61993b83ef8cfe3c84b1d169e984c
2024-01-22 23:32:01 +08:00
hoshi-hiyouga
3f3756b113 Merge pull request #2283 from A-Cepheus/main
fix: ZeRO3 does not work with MoE models
Former-commit-id: f5ea760abec2aac8d29ce5c945647be05648e676
2024-01-22 23:28:45 +08:00
hoshi-hiyouga
b36c4b99cc Update patcher.py
Former-commit-id: 33556cc6b0b65cc6db02e66f4f6e75112c33d966
2024-01-22 23:27:39 +08:00
hoshi-hiyouga
9856a2276e Update tests.yml
Former-commit-id: 34151675388701afa40220729a63b0d7b2fa2a7c
2024-01-22 23:22:15 +08:00
hoshi-hiyouga
b6dc3ed3ad Create tests.yml
Former-commit-id: 9443ad76b7ef3ef1f3d184ef60652947d2c30609
2024-01-22 23:13:04 +08:00
hiyouga
75be329994 fix #2282 and update tool prompt
Former-commit-id: 1c412f803866bde32b76f7c26c7b464b6b3651f3
2024-01-22 22:27:30 +08:00
hiyouga
1fe1ca1c8b add orion models
Former-commit-id: a34db89d2a281d1a1ace29dfd5bd5d4ff7c2f657
2024-01-22 21:26:53 +08:00
A-Cepheus
882a6a1d51 🐞 fix: typo
Former-commit-id: 57a3687ecd23237559aee0e8e811b782846f2415
2024-01-22 16:04:39 +08:00
A-Cepheus
712ab4ae7a 🐞 fix: typo, move MoE fix to patcher
Former-commit-id: 4ff28e99ff9b48df7150591c6bbd3723f22b7715
2024-01-22 16:01:58 +08:00
A-Cepheus
18ad259fb3 fix: ZeRO3 does not work with MoE models
Former-commit-id: b2844c049a88ea89f8e1812e2d2e8662b4002965
2024-01-22 15:21:14 +08:00
hiyouga
fe4d93c6db add array param format
Former-commit-id: bf910f8a5b21ee552fa9ab069610a3f5f611de57
2024-01-21 22:17:48 +08:00
hiyouga
c6ba588e37 update tool test
Former-commit-id: 1d63ccc2866632596310235de15fdff660f6bee5
2024-01-21 19:41:46 +08:00
hiyouga
3fda60fca0 fix api
Former-commit-id: cca004da28aaaa0788eaea62b83d3402b38a3011
2024-01-21 19:15:27 +08:00
hiyouga
96531a0ef8 fix #2268
Former-commit-id: 300ecf9b9d7fd99fbb68f3d086e3ad973c2f894e
2024-01-21 14:11:38 +08:00
hiyouga
7abc3065fb tiny fix
Former-commit-id: 66839ae94985ddfa13eca4542127119c919b9648
2024-01-21 13:26:12 +08:00
hoshi-hiyouga
013ded4bac Merge pull request #2266 from yhyu13/fix_export_model_dtype
Remove manully set use_cache; torch_dtype is not str, save model as b…

Former-commit-id: 8c0827ba92a458e18c3b68af0330af3a65149f96
2024-01-21 12:40:39 +08:00
hoshi-hiyouga
010c3c7348 Merge branch 'main' into fix_export_model_dtype
Former-commit-id: 6c7d2729f28eb37a97820d73c05648eb7fb2ca87
2024-01-21 12:40:24 +08:00
hoshi-hiyouga
bf075c075c Update tuner.py
Former-commit-id: 691420661f7115f809e76484c1f29f74637e7cd0
2024-01-21 12:39:38 +08:00
hoshi-hiyouga
41b34e5f60 Merge pull request #2262 from fenglui/main
fix torch_dtype check of export_model

Former-commit-id: 37cacf73a534fed1b06b4f3c6724f3568ce095e3
2024-01-21 12:34:37 +08:00
hiyouga
5a889398e7 format
Former-commit-id: f28a1a0c1cdd0062ad7b6c2826f8ec107a200cff
2024-01-21 12:34:17 +08:00
hoshi-hiyouga
054cae86d8 Merge pull request #2264 from seoeaa/main
add russian lang

Former-commit-id: 15d1747de54efe69ed9f4cfd8f296fe8dd09a5c9
2024-01-21 12:25:24 +08:00
yhyu13
cd1cb8b83c Remove manully set use_cache; torch_dtype is not str, save model as bfloat16 used to fail;
Former-commit-id: 75557fb5df16fd6eda7586cf041a16822dcfee8e
2024-01-21 11:12:15 +08:00
Aleksandr
a34779c027 add russian lang
Former-commit-id: f8ce6d75b56439027bb17ff4e59eeb9eb3b9bd34
2024-01-21 04:28:14 +03:00
fenglui
d19cb77d74 fix torch_dtype check of export_model
Former-commit-id: 8813181b6bffa76e5c7cb1f4caceada611c54b9d
2024-01-21 05:01:53 +08:00
hiyouga
ab67528e89 release v0.5.0 (real)
Former-commit-id: 2146e1d9195c179fa8f92144ec2b7034e1a9f942
2024-01-21 01:54:49 +08:00
hiyouga
27f281480a finish agent
Former-commit-id: d8d9d3afe32725fe79120fcd1a0970fdcdc45625
2024-01-21 01:47:33 +08:00
hiyouga
50459a39f4 fix api
Former-commit-id: a4149fbcd600d4f3815f9353e5e92c569719bed6
2024-01-21 00:03:09 +08:00
hiyouga
5c9815ef6f fix internlm2 template
Former-commit-id: ae05b23eb86555dbfafc174aa6ceff736e7fc9fa
2024-01-20 23:33:50 +08:00
hiyouga
aed00a97b6 fix cli_demo
Former-commit-id: e8336b3653f43618cf7cd70f8da004208de970c0
2024-01-20 23:27:10 +08:00
hiyouga
7543dc4a9d fix #2260
Former-commit-id: ba97550671811a27177306dd231bb427130b26fb
2024-01-20 23:22:09 +08:00
hiyouga
841fa0030f release v0.5.0
Former-commit-id: 602bb9b685009b9af234499be278404721542ac7
2024-01-20 20:21:39 +08:00
hiyouga
66e0e651b9 format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
2024-01-20 20:15:56 +08:00
hiyouga
1750218057 fix tests
Former-commit-id: 23f97bd437424ef43b2b84743d56acc5d1ca70d5
2024-01-20 19:58:04 +08:00
hiyouga
80637fc06d support longlora for main branch
Former-commit-id: f869501ad4c368df26534c41f62c6d63c6be17dd
2024-01-20 19:25:22 +08:00
hoshi-hiyouga
8efc055511 Merge pull request #2201 from liu-zichen/token_embed_resize
support resize embed for zero3

Former-commit-id: c0d1b5e3aef70da6b115614bd1ed539a76d6547a
2024-01-20 17:45:38 +08:00
hiyouga
be61bfda93 add upcast_lmhead option
Former-commit-id: 7ef69a1697c11ff13e7503360e40ef36cfb1c345
2024-01-19 23:54:25 +08:00
hiyouga
1a39f529c0 set use_reentrant=False
Former-commit-id: efa2e27d5ef6eaeb7baa7551c651ef10ab31400c
2024-01-19 23:29:54 +08:00
hiyouga
0868d5c550 fix #2249
Former-commit-id: 7ec64588c541422875adfdaf5692a27d05b96cb9
2024-01-19 21:44:32 +08:00
hiyouga
384f0e7678 add bf16 lora option
Former-commit-id: 58e7d7ff0cf9bf30e53b3eb12576f38d31976413
2024-01-19 16:29:03 +08:00
hiyouga
9b390c4bea fix function formatter
Former-commit-id: 363a87376ad8fe4149b387f7ccd60f31f2a5fdf7
2024-01-18 16:01:07 +08:00
hiyouga
42a13fec46 Update tuner.py
Former-commit-id: db30107385f100f88c9370abea6692ce6030a0c9
2024-01-18 15:06:02 +08:00
hiyouga
790acc4c17 fix templates
Former-commit-id: 382cc48b2a823b9a7d4ccf2c2a163f0e5b6e3169
2024-01-18 14:49:52 +08:00
hiyouga
b74cf27538 fix rm dataset
Former-commit-id: fa6f810026a59cecce813a696b2fdf15ba502fc4
2024-01-18 14:45:37 +08:00
hiyouga
ffc874ec6f fix pretrain data loader
Former-commit-id: 2a812b706ecc527013e79edc504ec18a4193123d
2024-01-18 14:42:52 +08:00
hoshi-hiyouga
546d6bd0b2 Merge pull request #2226 from hiyouga/dev
support function calling

Former-commit-id: 69391464f0d3fb0e2ef76e6b6fac51c119d66b53
2024-01-18 14:31:28 +08:00
hiyouga
8b68ca029e update readme
Former-commit-id: 11e0c732c4968b083f60a0bb6f7bb5dd5ca2ba56
2024-01-18 14:30:48 +08:00
hiyouga
502f84b30c add tool hint
Former-commit-id: 64734ffe2f45f80a1e33c2a72330b2ab1e58feb3
2024-01-18 13:19:09 +08:00
hiyouga
b7df920860 fix dataset
Former-commit-id: a7ce244a6d83d62f5bbecc588f1978e3791fd3b3
2024-01-18 12:59:30 +08:00
hiyouga
e4a424cb6a enable cutoff len
Former-commit-id: e9513d300c338dfcae98eee7d057bfd00da2da0e
2024-01-18 12:25:42 +08:00
hiyouga
d8affd3967 add tool test
Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
2024-01-18 10:26:26 +08:00
hiyouga
a423274fd9 support function calling
Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
2024-01-18 09:54:23 +08:00
hiyouga
f7329b1a0e Update llamafy_internlm2.py
Former-commit-id: 3ca5915a4fcd3d28d10a47bf9f2188b5cf8393a8
2024-01-18 01:12:31 +08:00
hiyouga
48eb07c956 Update llamafy_internlm2.py
Former-commit-id: 69b3cb768eda57b63f47cd35e5da3a59b57b7853
2024-01-18 01:00:16 +08:00
hiyouga
636d8a886c Update llamafy_internlm2.py
Former-commit-id: 1f1a7bcee5a5bb0fa17b13aa6393bfba89451dd7
2024-01-18 00:49:31 +08:00
hiyouga
97b52c7fdf fix llamafy scripts
Former-commit-id: 99ff69c36767d4397a4a61e89317ec8c0c295c1e
2024-01-18 00:37:37 +08:00
hiyouga
344412e66e fix llamafy_internlm2
Former-commit-id: a309375d020dedc313f3b6921fb53d932f156e8b
2024-01-18 00:26:14 +08:00
hiyouga
5cdea14cdf add llamafy_internlm2
Former-commit-id: 7b71767ef67cd5f246f52fb7e74b36bd26774a6c
2024-01-18 00:17:41 +08:00
hiyouga
7b1a56b96f support export push_to_hub #2183
Former-commit-id: fac09da7123a500d255de74810a8d057fb5c5f07
2024-01-16 23:59:42 +08:00
hiyouga
d1ec884e75 fix #2195
Former-commit-id: 801f7279693a0c785480ea67d663d99f4ca653da
2024-01-16 23:53:50 +08:00
liuzc
aa72a4349e support resize embed for zero3
Former-commit-id: b5464f5699b13bb118ac57ebc40b3cf9eb030396
2024-01-16 15:16:20 +08:00
hiyouga
5ab7fd0842 tiny fix
Former-commit-id: 6b1e9207e988c253a808e6bb26e3af9d071b77bc
2024-01-15 23:34:23 +08:00
hoshi-hiyouga
86d5e9802a Merge pull request #2194 from junuMoon/patch-1
fix: typo on README.md
Former-commit-id: a066a633a1a4b50cd6dc6b50701e35532fe788c1
2024-01-15 20:21:28 +08:00
Junu Moon(Fran)
18df39e3a1 fix: typo on README.md
Former-commit-id: 372066b559305a1428c88fbd6b01e332bfd5e3e1
2024-01-15 19:50:35 +09:00
hiyouga
cfe1e24471 support solar 10.7B #1907
Former-commit-id: ecf9b35c612e5514dd25b0d15835d28447a7437e
2024-01-14 00:30:30 +08:00
hiyouga
2edbe87a8c Update README_zh.md
Former-commit-id: e6d704c383e36abe8e27b3834f41d95890858425
2024-01-14 00:17:28 +08:00
hiyouga
880055bc90 support deepseek moe
Former-commit-id: 07fbb32496b9b81c4cfe67cb9a15a6b2c43852c3
2024-01-14 00:14:49 +08:00
hiyouga
ad99bd0a14 fix phi modules
Former-commit-id: 68d7e925ec51b6ee277513de8f61ac18a8378b98
2024-01-13 23:12:47 +08:00
hiyouga
c5f099138d fix #2147
Former-commit-id: 49445a03cd46af4e7036cf444cd041dfab2d8941
2024-01-12 03:30:56 +08:00
hiyouga
6e64e02f71 fix #2164
Former-commit-id: abe23bb4aca4fa571ebafc329ec9a9d457e37d41
2024-01-12 00:27:57 +08:00
hoshi-hiyouga
f95f6ec009 Merge pull request #2163 from JessyTsu1/main
请求添加"Projects using LLaMA Factory"

Former-commit-id: fa9abb430b204fabe4c1b3a569225695ae0cbc29
2024-01-11 23:33:29 +08:00
JessyTsu1
8aeecc20e1 Update README.md
Former-commit-id: 547d4df5c7a1d6dd95cfed37229701ce507b421c
2024-01-11 23:18:29 +08:00
JessyTsu1
38d0f6c63f Update README_zh.md
Former-commit-id: 8677309a38140ec1e1be3f81d0b2024df3f16c21
2024-01-11 23:17:48 +08:00
JessyTsu1
ac8534a9e7 Update README.md
Former-commit-id: dcd4858fd2c2ac4d3cce8a369dc9991108c03821
2024-01-11 23:17:00 +08:00
hiyouga
73cab9d9d4 fix #2161
Former-commit-id: 9acd5a2b678cd07f8e3b48eca76c4cbacb559e37
2024-01-11 17:04:13 +08:00
hiyouga
64246d42d2 improve web ui
Former-commit-id: 5c0148c018b12b52bc5748acfd6ad43836f2edb5
2024-01-10 12:37:45 +08:00
hiyouga
6fa6d4532e improve model export
Former-commit-id: d1b795aac1fccbcb8a9ec2057065c33b46ce1a5a
2024-01-09 22:26:24 +08:00
hiyouga
92b9956c06 modify weight name
Former-commit-id: 3f3c528fa8056dc1952ea5293bad7e55187983ff
2024-01-09 20:22:47 +08:00
hiyouga
4d6669c268 fix #1789
Former-commit-id: d86455f685fa531e651333e00b4fe54d895cf2e4
2024-01-09 18:31:27 +08:00
hiyouga
89f4ae51f9 fix #2127
Former-commit-id: 5a1aa33fa9b546ab520f0ba4cb9d996b87eb71ca
2024-01-09 14:49:13 +08:00
hiyouga
af0659f573 fix #2125
Former-commit-id: 46a22f4daeafac5b0a695212d060960ff53af613
2024-01-08 21:42:25 +08:00
hoshi-hiyouga
45a10d501e Merge pull request #2117 from dasdristanta13/main
Update requirements.txt With einops dependency

Former-commit-id: af0c05f1cffc7fc0fc74d514783333501f83f59e
2024-01-07 23:56:53 +08:00
Dristanta Das
e529ff1245 Update requirements.txt With einops dependency
Former-commit-id: 0b47b13cb34cace6fa0b6d0c58ca16fb01b3a5e9
2024-01-07 21:03:30 +05:30
hiyouga
b29371dc87 tiny fix
Former-commit-id: 06b854fe15eb4cf4ff8d6f5570068d9e74a2f1b3
2024-01-07 17:17:18 +08:00
hiyouga
0bef890000 fix api server
Former-commit-id: cedd80ba56c0090487f65f4b1227e5615943997f
2024-01-07 17:14:42 +08:00
hiyouga
75fe1404b1 improve model export
Former-commit-id: 31255147a566a23ce1a48402662d14af8ac267ab
2024-01-05 18:51:49 +08:00
hiyouga
b460c9372f fix #2098
Former-commit-id: e62d9158cffbf1044396597ddaf15b1c0bc5f954
2024-01-05 17:11:26 +08:00
hiyouga
c3e574ceaa fix qwen template
Former-commit-id: c1923e0daa02b49ac07e96ce29877729acc78d31
2024-01-05 16:14:56 +08:00
hiyouga
04ae80a52e fix #2081
Former-commit-id: ec4b539b6c0be11e15d273025c414b694bbd6c9a
2024-01-04 23:19:08 +08:00
hiyouga
a7ff095399 fix #2090
Former-commit-id: 13ec720990a88b01f7f5e2a99a87f95128dc3537
2024-01-04 23:05:08 +08:00
hiyouga
a655dcebaf fix #2067
Former-commit-id: 6cfdeea5261fd5bf6f91ba2bb3efb921a2f3e866
2024-01-04 22:53:03 +08:00
hiyouga
8c74851b70 fix dispatch
Former-commit-id: deda82638716506dc690902c51276bb1eb0ddd5e
2024-01-03 16:33:16 +08:00
hiyouga
7168392a51 fix valuehead patch
Former-commit-id: d9cb98362b58b28ae0ee207e7c07e75e5d810876
2024-01-03 16:19:23 +08:00
hiyouga
ccc5b324fe fix rm server
Former-commit-id: 81bc1638682a9fd01518f9f25250a6b584d2a9e6
2024-01-03 15:30:46 +08:00
hiyouga
e85c205a81 fix #2014
Former-commit-id: 077f6bf64e50f01f62aa4a957438bedc4e7925b3
2023-12-29 15:17:22 +08:00
hiyouga
7e225be16e add yuan model
Former-commit-id: 6a0377e2e51633bd5fb10fa8628e554565c5ee3e
2023-12-29 13:50:24 +08:00
hiyouga
ebb32e85f8 fix version
Former-commit-id: dd7500b65d0d548441eece101b60d51fa619cc0f
2023-12-29 04:53:36 +08:00
hiyouga
90d279f39f fix args
Former-commit-id: ff18f327a3dc96d9677ef32841e8f29ab2eeb7ef
2023-12-28 18:47:19 +08:00
hiyouga
af3f5b6e16 fix export format
Former-commit-id: 7c82bd396b9e6ff395850ad544d95cbf1b7557cd
2023-12-28 18:40:46 +08:00
hiyouga
53d7c5109f fix ppo trainer
Former-commit-id: ca5b5823b03822ef899405d233a82396be997f44
2023-12-28 18:09:28 +08:00
hiyouga
bf381563ff add model link
Former-commit-id: 159729929516f68aa1f43a852ed50ca0fac81523
2023-12-25 19:44:38 +08:00
hiyouga
de4b9334e1 tiny update
Former-commit-id: 4417b8ee20b381c964f452f52081667dfa33cd7b
2023-12-25 18:29:34 +08:00
hiyouga
c33fbea469 fix bug
Former-commit-id: b06faa1be3f5aa5e0fa31aa31314c213c36c3442
2023-12-24 19:20:12 +08:00
hiyouga
921f593632 update loader
Former-commit-id: 080d8eab858217ca58bffe719d5ffde7579c5bda
2023-12-24 19:10:23 +08:00
hiyouga
940403720a update patcher
Former-commit-id: d6d7b6670847ce4ea10353c5b126214542b45c2b
2023-12-23 15:24:27 +08:00
hiyouga
f869e44fe5 fix #1909
Former-commit-id: 3e93c33af9f80e28c9f30af9b7ba20757358afb4
2023-12-23 14:42:20 +08:00
hiyouga
bcc92919a0 update readme
Former-commit-id: d3dea7a926e9d356a39ca2033b03be7f559cc143
2023-12-23 02:17:41 +08:00
hiyouga
306a70c7ba fix unsloth dtype
Former-commit-id: fd22e6546ce5f38a6a075cf894aafc3d206b2fcd
2023-12-23 01:59:49 +08:00
hiyouga
d358d955e5 fix dpo trainer
Former-commit-id: c160dd7cd86e296e32775ace2e4258a473449c41
2023-12-23 01:51:55 +08:00
hiyouga
0fdd6074c3 llama board: add unsloth
Former-commit-id: 9477e6f28808ae9deadada1f6cf679a29542c271
2023-12-23 00:35:53 +08:00
hiyouga
6faf9c35a9 support unsloth
Former-commit-id: b857f00234b90b785d82ca7cdb29af3d948b1a7b
2023-12-23 00:14:33 +08:00
hoshi-hiyouga
1066898e32 Merge pull request #1953 from ShaneTian/model-load-bf16
Fix slow model initialization in bfloat16 dtype.

Former-commit-id: 69daf107c4561f807ceae066f04d432323699cef
2023-12-22 17:29:54 +08:00
ShaneTian
d05febe5de Fix slow model initialization in bfloat16 dtype.
Former-commit-id: cf2e2f6f9b7f09b1e2faf6fbc413e3f62e3846c7
2023-12-22 16:27:28 +08:00
hiyouga
67f7034a21 fix param type
Former-commit-id: 11b99f344416ade1cdac52e11ba7f36fcf689221
2023-12-21 17:33:01 +08:00
hiyouga
79f301a2c6 fix ds zero3 check
Former-commit-id: 7f50705b1d821d287bd854211319f697f57b25db
2023-12-21 01:19:22 +08:00
hiyouga
31cbc67986 match version
Former-commit-id: 16db52522584a8e084d4db2a7c253c8b88f27371
2023-12-20 22:17:35 +08:00
hoshi-hiyouga
fe66bf3663 Merge pull request #1932 from ShaneTian/main
Update transformers to 4.36.2 to resolve multi-node saving bug.

Former-commit-id: 5c55907a57e8327134e2c982c838a53c9fa42f51
2023-12-20 22:13:28 +08:00
ShaneTian
4691d4b35d Update transformers to 4.36.2 to resolve bug when saving a checkpoint in the multi-node setting.
Former-commit-id: 3173f8e51eec5e8f488e3dfc54ad371b640d6b87
2023-12-20 22:00:41 +08:00
hiyouga
acf5241845 fix stop words
Former-commit-id: 6ce6cac9fa8f0af33697e824cf93a9a80cdbd064
2023-12-20 19:06:43 +08:00
hiyouga
2bce99b82f fix yi template #1895
Former-commit-id: 05b4fa1e2b13a15ee261a151ac8cd0a2ebdf5edc
2023-12-20 18:58:16 +08:00
hiyouga
3c330869ef improve quantization
Former-commit-id: 4dde60017ad8208dfea0b2bb61df6a14a35d03e0
2023-12-20 18:27:16 +08:00
hiyouga
dba1af4841 add max_memory for gptq #1923
Former-commit-id: 9afc42c8b999fbbc206d9a467ca5795b27a10096
2023-12-20 18:15:17 +08:00
hiyouga
2b1e52dcc9 fix #1073 #1462 #1735 #1908
Former-commit-id: cd8e2535aa66931b24b96e76c2b56ce703a579b1
2023-12-20 17:15:40 +08:00
hiyouga
b5238e945a optimize data loading logic
Former-commit-id: 58f669b384582ac90e85de835f1f44f7003f9ec0
2023-12-20 16:15:41 +08:00
hiyouga
afc0f29704 fix #1909
Former-commit-id: f563e8d28dfa48a60cbe3d295b20f9cf58de296d
2023-12-20 16:11:07 +08:00
hiyouga
de0bb1d2da fix mixtral inference #1821
Former-commit-id: 612f9fd19cbd29e8b1785a1576a9668e7dcd264c
2023-12-20 15:11:15 +08:00
hiyouga
cc16ece283 fix #1900
Former-commit-id: 4c35214396f873588562606b084740b6581188d9
2023-12-19 17:21:46 +08:00
hiyouga
31ba802fc9 update readme
Former-commit-id: 36cd747e6a1a568e1a03e6c6611fec48e6ab9df7
2023-12-18 22:29:45 +08:00
hiyouga
4b27cf5460 add codegeex template
Former-commit-id: a8222722b8097158f1c92e3729f41d411eff3926
2023-12-18 19:52:35 +08:00
hiyouga
a53b2a643f add xverse-65B-2 model
Former-commit-id: 3e563a0d9666934dfdab54d61654ec00079a93f1
2023-12-18 19:24:09 +08:00
hiyouga
d925ecae1b add models
Former-commit-id: 3a4728557304996bcbe58d7d6380beead7c63c70
2023-12-18 19:09:31 +08:00
hiyouga
13fd751a78 fix tokenizer for Yi chat models #1617 #1875
Former-commit-id: 9485692c8d367a0b25d3e653db413aa01cb9ad7d
2023-12-18 17:18:11 +08:00
hiyouga
74575f8922 update readme
Former-commit-id: 01267eee0da0bffb3f0c0378e2e60d14e05585c4
2023-12-18 15:46:45 +08:00
hiyouga
5e7bb5fe73 fix llama board
Former-commit-id: f43f61b2898dda56aba0066fcb3409b152260bdb
2023-12-16 22:17:37 +08:00
hiyouga
790a31404a fix #1742
Former-commit-id: efbb32afdcf0d6aa4ca26f54c95f76dbb84f77dc
2023-12-16 20:50:45 +08:00
hiyouga
f927601702 add xverse-65b-chat model
Former-commit-id: fff6288db6b61ca27010ea47c918298f76922106
2023-12-16 20:21:29 +08:00
hiyouga
c4654d54d7 set version
Former-commit-id: 45a05e3a415eeaf875e2cf15bdba0235fbd7d527
2023-12-16 20:17:51 +08:00
hiyouga
df777c30d1 add noisy mean initialization #1815
Former-commit-id: 3253b1fca0123071913079277186c160046edf21
2023-12-16 19:47:51 +08:00
hiyouga
d81ad2d4bc support dpo-ftx
Former-commit-id: 86dfa04f9821556019fa777106787f73eb70b452
2023-12-16 19:21:41 +08:00
hiyouga
9f77e8b025 support autogptq in llama board #246
Former-commit-id: fea01226703d1534b5cf511bcb6a49e73bc86ce1
2023-12-16 16:31:30 +08:00
hoshi-hiyouga
04dc3f4614 Merge pull request #1868 from yhyu13/improve_hfargparser
Improve logging for unknown args

Former-commit-id: 6455013a99ca5c63f5b99c1100e93f794a03c497
2023-12-16 16:06:09 +08:00
yhyu13
7d1fe50977 Use llmtuner logger
Former-commit-id: ef5a560b4246e04e0ef2612e3520e05288e93707
2023-12-16 07:15:27 +00:00
yhyu13
c0e5e3c5d5 Improve logging for unknown args
Former-commit-id: 03e49d76ca91f7fcaf1c013740d5f6bfc11a2028
2023-12-16 05:16:29 +00:00
hiyouga
3a45cfb604 update tips
Former-commit-id: 4432cbda6b7535bcbb40ba77df069fca631b4be8
2023-12-15 23:52:50 +08:00
hiyouga
393e4b0f5a fix #1770
Former-commit-id: 8266187cec70bb4bd1b4837d51b09409ec11f93e
2023-12-15 23:50:15 +08:00
hiyouga
296711d502 support quantization in export model
Former-commit-id: f32500ae6edccab7d14df4c92467e15986866def
2023-12-15 23:44:50 +08:00
hiyouga
9121722999 update dc link
Former-commit-id: f6789e50e17a377b6d9b434d8e12ad99d8eecfeb
2023-12-15 22:11:31 +08:00
hoshi-hiyouga
d8d74091f6 Merge pull request #1864 from hiyouga/dev
Refactor hyper-parameters of adapters and model loader

Former-commit-id: d5ce2fb6858b9f2963f355e9f4d6f046eb6efdcd
2023-12-15 22:06:56 +08:00
hiyouga
33521fb45e fix bug
Former-commit-id: 95ac272907a04a64785f928536de1fd099150f92
2023-12-15 21:54:02 +08:00
hiyouga
e5204e60ed fix bug
Former-commit-id: 8b80baf02cfece53527c27712f0899fa3532c414
2023-12-15 21:49:26 +08:00
hiyouga
0409428d87 add configurer
Former-commit-id: c40c9889615ffb49c7ce24c69c0d3d20d841c800
2023-12-15 21:46:40 +08:00
hiyouga
f902b0d420 refactor adapter hparam
Former-commit-id: f82aece9ebd6df83a7a005cc7cbbcec07fa6e14d
2023-12-15 20:53:11 +08:00
hiyouga
27ef5b1aa7 add loftq
Former-commit-id: 0b900882ef19ac49604a24fbae8b3254f1bff7ad
2023-12-14 21:53:56 +08:00
hiyouga
c32303fc7e fix valuehead model
Former-commit-id: 9f628debb6510f2d1c91b00f121a721ab5d648e9
2023-12-14 20:15:20 +08:00
hoshi-hiyouga
45abe361ba tiny fix
Former-commit-id: 987df4c62f34026adfe2089910f4ff9ac6ebd9a6
2023-12-13 17:32:36 +08:00
hoshi-hiyouga
3ae479faae revert peft version
Former-commit-id: 6440fa1a8c28fd2db58d0905a67d071837e0edd1
2023-12-13 10:49:45 +08:00
hoshi-hiyouga
5698038f49 update peft version
Former-commit-id: 31c01e1272bd2cd9588e5ee68c1924a3dd55c67e
2023-12-13 10:23:51 +08:00
hoshi-hiyouga
020233f725 tiny fix
Former-commit-id: 1478bc052417e0939188f55a0adcbf00956960f2
2023-12-13 10:21:29 +08:00
hoshi-hiyouga
6f9d55b8eb fix #1819
Former-commit-id: f2e2b0354cbe9a7190ccab807f690cc8ab433a6e
2023-12-13 10:14:01 +08:00
hiyouga
2542b62d77 remove loftq
Former-commit-id: e175c0a1c631296117abda2403a4b87bbdd35a66
2023-12-13 01:53:46 +08:00
hiyouga
95678bb6b1 fix sharegpt loading
Former-commit-id: ad35c35f9328bff69e8b9ea7dba6a61a2dc9e28b
2023-12-13 00:56:16 +08:00
hiyouga
a78759e7ee add model urls
Former-commit-id: 3139a9fafab246f5461697efd5ed7a6599d85481
2023-12-13 00:09:17 +08:00
hiyouga
cc5c523f58 update readme
Former-commit-id: e81037d766f89f7e2b6539596397983eba52b492
2023-12-12 23:30:29 +08:00
hiyouga
e39bbdd287 support loftq
Former-commit-id: e7ac2eb7f7daae17525a278ffbe2f82c0fbd8093
2023-12-12 22:47:06 +08:00
hiyouga
d9a50bf93f fix #1795
Former-commit-id: 949ab45487155525789c08027d4f8e7da1b8bc0c
2023-12-12 19:58:34 +08:00
hiyouga
934d00ea1e support system column #1765
Former-commit-id: f425584a511c5e42bae8b3ba090eaa898b28adad
2023-12-12 19:45:59 +08:00
hiyouga
c27675f70d fix modelscope data hub
Former-commit-id: 5b63e8c22538a4788e4b6c8df50e6e6be93ceeac
2023-12-12 18:33:06 +08:00
hoshi-hiyouga
7c9f37c83d Merge pull request #1802 from tastelikefeet/feat/support_ms
Support ModelScope Datahub

Former-commit-id: f73f321e765aab9325673218779ff4ee7f281514
2023-12-12 17:58:37 +08:00
hoshi-hiyouga
b9736c13e0 Merge branch 'main' into feat/support_ms
Former-commit-id: 698756dffb7d4e602b3e0cab66ef0a4befe7215c
2023-12-12 17:55:32 +08:00
hiyouga
c47725ff34 fix webui
Former-commit-id: 15ad266206b12181788db5bb112c2299050d6139
2023-12-12 15:27:40 +08:00
xingjun.wang
3ee3fe0bbb add use_streaming
Former-commit-id: 80388abdb7ee88eb4afad92d8c706370c0574039
2023-12-12 14:23:05 +08:00
xingjun.wang
e54dad75da fix cache dir
Former-commit-id: 6231272b9c51d44196f1fbec026973231e489b67
2023-12-12 14:21:33 +08:00
xingjun.wang
39c2f03eab add print info for test
Former-commit-id: e4ae2fccf0cbec57fb5fb01fd7cc352da69b23bf
2023-12-12 14:14:40 +08:00
xingjun.wang
fb9e1c4087 update cache dir
Former-commit-id: c8a1ce847fd7a75a06659133d92a0ac42e52a839
2023-12-12 13:08:18 +08:00
xingjun.wang
ed26bb3d82 update args for MsDataset.load
Former-commit-id: c5f69357a167cbf99a93607177526e787419ea05
2023-12-12 13:02:54 +08:00
xingjun.wang
0baf32e219 update
Former-commit-id: e15fc417d897c3063a25d6eb7eb89d1916db3cc5
2023-12-12 12:03:23 +08:00
xingjun.wang
79a376d1db for test
Former-commit-id: 33d9082320098f994bfa0c6353459afcb93165b7
2023-12-12 11:52:59 +08:00
xingjun.wang
b634e91c43 for test
Former-commit-id: 95ea942bd32402018e7c5dc61d50153c602ab67a
2023-12-12 11:47:59 +08:00
hiyouga
9e2cc21d04 update readme
Former-commit-id: 42e042a4206aeb5177ddde56386e9655b0c06460
2023-12-12 11:44:30 +08:00
hiyouga
6975124a57 support mixtral
Former-commit-id: 75b5b8e36ab1933b2625f11b645f56cbc805fd85
2023-12-12 11:39:04 +08:00
hiyouga
9f69307db1 fix baichuan resize
Former-commit-id: 66956d13074a9bc74d7a737b9476f38361a7764a
2023-12-11 20:55:50 +08:00
hiyouga
c3448a045c tiny fix
Former-commit-id: 1f839fc4f278c2a258df22899241fc66a2cca682
2023-12-11 18:09:40 +08:00
hiyouga
95c561983c support resize embeddings #1786
Former-commit-id: 368a41bd3c6a04f869083058d9165954fbdad105
2023-12-11 17:50:02 +08:00
hiyouga
7a03c8dab5 use peft 0.7.0, fix #1561 #1764
Former-commit-id: 423947bd58aa50da8785b8ceca1e7e288447a9da
2023-12-11 17:13:40 +08:00
hiyouga
f3ffa8310f fix #1784
Former-commit-id: 4e1af5a5d39d9e2f374c1372e2d67120c63fea09
2023-12-09 20:53:18 +08:00
yuze.zyz
596f496f19 support ms dataset
Former-commit-id: 98638b35dc24045ac17b9b01d08d3a02372acef3
2023-12-08 18:00:57 +08:00
hiyouga
2e6ed731cf fix #1771 and temporarily fix #1764
Former-commit-id: d0e5a5d604e16c2fe0035b0ac1d54dc3625d4da3
2023-12-08 16:26:20 +08:00
hiyouga
24ce319b6f add models
Former-commit-id: 758ae7937a41a95016e70180fb343011763c1b67
2023-12-06 13:33:18 +08:00
hiyouga
7b7bfea37d fix ppo trainer save logic
Former-commit-id: 5e70c41e4e12a1109570b0ff56346fe212c028ed
2023-12-04 19:00:19 +08:00
hiyouga
3be461260a update readme
Former-commit-id: a15f8cf19cac42acfb9917a2d7c9fa36a838b360
2023-12-04 11:22:01 +08:00
hiyouga
8dab8d9831 update readme
Former-commit-id: d3c46cb126a9182be765341fe31c860d71430712
2023-12-04 11:02:29 +08:00
hiyouga
fb4c5f3c91 fix #1715
Former-commit-id: 3f9192dbbbafdc2171d2eb80282d5cae47565b7b
2023-12-03 22:35:47 +08:00
hiyouga
5fe3cce5a3 release v0.3.3
Former-commit-id: 72ddb5fcce1649599671de214667d8d899ef5203
2023-12-03 21:59:45 +08:00
hiyouga
09f165d442 fix bug
Former-commit-id: 2fd7a8fc3134af66193a5e8db8fea35025f82de9
2023-12-03 21:40:40 +08:00
hiyouga
60aea7521b ppo support rm server
Former-commit-id: 20b0edf16f5b42cb2c4a795674647afb68cb3a4a
2023-12-03 21:38:51 +08:00
hiyouga
29545d0e5e implement rm server #1543
Former-commit-id: 2e5bb6888c86079493456c2ddd525f8c52b9963e
2023-12-03 20:52:54 +08:00
hiyouga
4a14099cfd fix #1707 #1710
Former-commit-id: 243a596518ad69cf1eec20a082534b9e94353ce4
2023-12-03 11:33:12 +08:00
hiyouga
b052574ddf add logo
Former-commit-id: 597894ad31c186120335252ccc0cc48fcea701b4
2023-12-02 01:31:24 +08:00
hiyouga
5ea6a7c6d6 fix #1642
Former-commit-id: 11be28201f688ac21cf94135067d37e9aa7ab0a1
2023-12-02 00:37:53 +08:00
hiyouga
8ca196d51f add xuanyuan models
Former-commit-id: 1dfa9de3723550cddf24bbc0739cad6207731212
2023-12-02 00:35:29 +08:00
hiyouga
5f572cbd77 fix gptq training
Former-commit-id: bec58e3dc575aa4247e563881a456328ee5ef496
2023-12-02 00:27:15 +08:00
hiyouga
679bd3ab30 tiny fix
Former-commit-id: fd2782a06ba4efa76cacbb49eb76a05de8d8aca6
2023-12-01 23:37:10 +08:00
hiyouga
da3d59fada fix gptq model inference
Former-commit-id: f7da9a87cb48cacb7d56322817b05d6f471f6508
2023-12-01 23:34:14 +08:00
hiyouga
835d27151d update readme
Former-commit-id: a0a9408e11f6b4cfb39af3f28402353b7cf48fa6
2023-12-01 22:58:29 +08:00
hiyouga
f1d7228a74 fix #1703
Former-commit-id: eee2e9abf6df345c5471e8ca7639293543ba720c
2023-12-01 22:55:41 +08:00
hiyouga
72bbd5bdef patch modelscope
Former-commit-id: 8888cf53f040f5a2d8c0e59cddf79b252449bf58
2023-12-01 22:53:15 +08:00
hoshi-hiyouga
ad9d866547 Merge pull request #1700 from tastelikefeet/feat/support_ms
Support ModelScope hub

Former-commit-id: f79c3b663a91ac2a7cdcf71192b6dd84f110b8f1
2023-12-01 20:25:18 +08:00
hoshi-hiyouga
a1ec668b70 Merge branch 'main' into feat/support_ms
Former-commit-id: b8954342611e24bc3af972747fd016cde89eee3f
2023-12-01 20:23:46 +08:00
yuze.zyz
389687a56d remove useless code
Former-commit-id: 323df46dd6a8eaf1fd608380406dcbce80c097b2
2023-12-01 17:28:23 +08:00
tastelikefeet
97280c73b9 fix bug
Former-commit-id: 6d483e76141420e0cb577541e6e1794c20f025f6
2023-12-01 17:27:00 +08:00
hiyouga
f3c622b665 fix err hint
Former-commit-id: 935a4a01bd9204129dd72a500ed75b268714d1e8
2023-12-01 17:13:22 +08:00
hiyouga
d71e8d8dbf add err hint
Former-commit-id: 2cf0249ec6f7524c39a6c8df73593f6d25b665b7
2023-12-01 17:04:37 +08:00
hoshi-hiyouga
02c2089ac8 Merge pull request #1699 from Samge0/patch-1
Update .gitignore

Former-commit-id: ab9da1bc5043fedeac8e57614e5986ebdd2128af
2023-12-01 16:52:57 +08:00
SamgeShao
07ad28a053 Update .gitignore
Former-commit-id: b2ec86ef63683665382c2fda142c3d9743e3c8a7
2023-12-01 16:37:41 +08:00
yuze.zyz
d323ccc3ec add readme
Former-commit-id: 3d5ec6f12b4ae7d04520e6865516a9a6dd4f7efe
2023-12-01 16:11:30 +08:00
hiyouga
4738d002c7 tiny fix
Former-commit-id: 37aa7099dff2a9a7b52e259dac92de41ce606946
2023-12-01 15:58:50 +08:00
hoshi-hiyouga
ec099b0586 Merge pull request #1695 from Samge0/dev
Improve:"CUDA_VISIBLE_DEVICES" read from the env

Former-commit-id: b49cde0c29774820dcf4463e3f1ef00114af7219
2023-12-01 15:56:18 +08:00
hoshi-hiyouga
a51253fea2 Merge pull request #1690 from billvsme/main
Improve get_current_device

Former-commit-id: c3b8cc27c91248a7381b3333abf099064412dc1a
2023-12-01 15:44:35 +08:00
hiyouga
304ec9ec6a fix #1696
Former-commit-id: 722ae14a652af34d9b91f9459e613d7959ecaa7e
2023-12-01 15:34:50 +08:00
tastelikefeet
8547085615 add model
Former-commit-id: 48e8d8438bc6cd2c75dc39419c45aaebb34a2e0a
2023-12-01 15:06:17 +08:00
samge
14b139ecb5 Improve:"CUDA_VISIBLE_DEVICES" read from the env
Former-commit-id: 7a61daa8be76779c876d685c57c464133ca70752
2023-12-01 11:35:02 +08:00
billvsme
7b45f5068f improve get_current_device
Former-commit-id: 2b07815e7fc8dc6ad0a7e9eccdd6681fbab35f3c
2023-11-30 22:40:35 +08:00
hiyouga
99ceee840e fix #1597
Former-commit-id: d77a3a79a0e854803a57af8ac6a7246691f69f70
2023-11-30 21:47:06 +08:00
hiyouga
8ed68301e3 fix #1668
Former-commit-id: bccc71259e703ca1e1d88169e385a026c4efa92e
2023-11-30 21:02:00 +08:00
hiyouga
664267e050 fix #1682
Former-commit-id: 06d56696731eadbeeea615eae4efce1b6c36def4
2023-11-30 20:03:32 +08:00
hiyouga
7ef8f46591 add models
Former-commit-id: b9eaadde8b5f4b9f89fa7bb910b325fcf9c84434
2023-11-30 19:16:13 +08:00
yuze.zyz
6933c1fed2 fix
Former-commit-id: e8774b4c9cbc8f894621ec72957f720d5c83d22b
2023-11-29 21:43:58 +08:00
yuze.zyz
9d125bf533 support ms
Former-commit-id: fdd4f94f563110ef9f96ab4a7fd954def32e9785
2023-11-29 20:36:55 +08:00
hiyouga
08d5340bd8 add gpu requirement #1657
Former-commit-id: 8581a9133790573031d9615a551fb677eb3be461
2023-11-29 12:05:03 +08:00
hiyouga
0e6f4f981e fix #1658
Former-commit-id: 3126687c4820c34daa6a2e9e3bf9065ad59e92dc
2023-11-28 20:57:24 +08:00
hiyouga
670ee3934f fix #1659
Former-commit-id: e4123129aae59f4123d53c1f5320e3d5e09ae26d
2023-11-28 20:52:28 +08:00
hiyouga
569860d7ac support export size setting
Former-commit-id: 1a4de54586c21cdbbc89f8a716ca5a54c87a6120
2023-11-26 18:34:09 +08:00
hiyouga
953a562ec1 support Yi-34B-Chat models
Former-commit-id: 1751a79c27e7fc13e76a731a061dc0c10d828cda
2023-11-23 19:31:49 +08:00
hiyouga
7f54008d3c update readme
Former-commit-id: 561481a8008fde5a3273558460193864a09866ed
2023-11-21 13:15:46 +08:00
hiyouga
5f5959bc33 set version
Former-commit-id: 6b47ad74c7b3099f9b5087c73db4aee42c451297
2023-11-20 22:57:44 +08:00
hiyouga
0105cd48f2 support GPTQ tuning #729 #1481 #1545 , fix chatglm template #1453 #1480 #1569
Former-commit-id: fdccc6cc9b68890199e9250cabdb996ff2f853b9
2023-11-20 22:52:11 +08:00
hiyouga
28258aecd2 update ppo trainer
Former-commit-id: caa525a5c6f228b9ad71387d1fe4f1c2ffa2479e
2023-11-20 21:39:15 +08:00
hoshi-hiyouga
e585950c54 Merge pull request #1553 from hannlp/hans
Change the default argument settings for PPO training

Former-commit-id: 1b64678fa4979485f67c3bb1420dfdff6fcbc6e7
2023-11-20 20:32:55 +08:00
hiyouga
bcd661afa6 fix value head model resuming
Former-commit-id: ccf0b65d886c09c7c49977c43b0544fe1bfcc258
2023-11-20 19:01:37 +08:00
hiyouga
adf2730d1d fix #1567
Former-commit-id: 8c01ffe8d277d49a413571e0669f460c8d0802bf
2023-11-20 18:46:36 +08:00
hiyouga
ba2be6371d better data streaming
Former-commit-id: 65ac8e84fd6f22255c587b20382fdf5d8131d015
2023-11-19 23:32:47 +08:00
hiyouga
d2ff09a404 fix model card network issue
Former-commit-id: 36155cd1893bea036f15c648c06b0047c02dfb4f
2023-11-19 23:03:19 +08:00
hiyouga
9f364d3880 fix Mistral template
https://github.com/lm-sys/FastChat/pull/2547

Former-commit-id: d426ecdf6e95402fc36893f7e4f17f881e1b957b
2023-11-19 16:29:30 +08:00
hiyouga
cfad41b901 fix #1263
Former-commit-id: faff5d32621f187ebd3124d7ade04e3fa437c53e
2023-11-19 16:05:18 +08:00
hiyouga
6889f044fb fix #1558
Former-commit-id: 263b2b24c8a649b51fa5ae768a24e67def8e0e96
2023-11-19 14:15:47 +08:00
hiyouga
3d1ee27ccd fix evaluator and cached_file in 4.31.0
Former-commit-id: 970897da402f604220d45084d492de4dab809ba4
2023-11-18 19:39:23 +08:00
hiyouga
775ce62950 update benchmark
Former-commit-id: 1cd2ae910e3ffca92978772d000de6fde2f6bb13
2023-11-18 11:30:01 +08:00
hiyouga
821a6f2fa6 update readme
Former-commit-id: a4d86a4bea1cce2219a54def9dfd3fd732d48e72
2023-11-18 11:15:56 +08:00
hiyouga
5197fb2fad add benchmark
Former-commit-id: 85a09cb649be740a47359371499d821ee0d5c81e
2023-11-18 11:09:52 +08:00
hiyouga
92abe91d22 update dataset
Former-commit-id: a310b22b446118d90dd73906847ed3d01a574b50
2023-11-17 23:19:12 +08:00
hiyouga
a7bf0b85d7 fix quantization
Former-commit-id: 8268aefe8fba268065e24ffe159a9c49f7c6f3a5
2023-11-17 22:21:29 +08:00
hiyouga
5ce5ea84a9 fix #1550
Former-commit-id: c12acd21a5a500892ed739c79327ccd39fddad5b
2023-11-17 17:23:13 +08:00
Yuchen Han
992be39f90 Update README_zh.md
Former-commit-id: 3e8a17c92d700bcafbe6559ea689dc4c0ad0481a
2023-11-17 00:18:07 -08:00
Yuchen Han
cab80a3c56 Update README.md
Former-commit-id: c1532dc6fe5d5b427011bd5509a2bc44ee16d951
2023-11-17 00:17:36 -08:00
Yuchen Han
6af7107938 Update workflow.py
Former-commit-id: f70b7ffe6442217a222e0ef797c407f259a13886
2023-11-17 00:16:27 -08:00
Yuchen Han
bcd31cf245 Update finetuning_args.py
Former-commit-id: 30e3430553f1f7e09cd57ef2c9843b549746c618
2023-11-17 00:15:51 -08:00
hiyouga
85c4ccfef9 fix packages
Former-commit-id: c93175d18ad9a4b7b61629153acabf8d0c978dfc
2023-11-17 16:11:48 +08:00
hoshi-hiyouga
dc0f81aabc Merge #1544 from Outsider565/main, fix #1548
Fix: Change rouge-chinese package name to rouge_chinese
Former-commit-id: c24da51cb5d3f78d54dcbfb31b565fcac4783a76
2023-11-17 16:09:42 +08:00
Shaowen Wang
07f934566a Fix: Change rouge-chinese package name to rouge_chinese
To reproduce:
python:
importlib.util.find_spec('rouge-chinese') -> None
importlib.util.find_spec('rouge_chinese') -> ModuleSpec(name='rouge_chinese'...)
from rouge_chinese import Rouge
print(Rouge.__module__) -> rouge_chinese
Former-commit-id: a78b11d944b6cb7dbe2a1d8a24d240e196aa530a
2023-11-16 20:12:35 -06:00
hiyouga
77cb18e9e3 fix chatglm template
Former-commit-id: 6a4b79c2e0610a17012bf3e72a2b5e8bac060092
2023-11-16 22:54:15 +08:00
hiyouga
fccaecf730 Update bug-report.yml
Former-commit-id: 92ed2297c78d016113fa7f90cedc0933a0bb2be0
2023-11-16 19:37:35 +08:00
hiyouga
53cdfe8f73 add issue template
Former-commit-id: 4ca01a6b051043593541403d74e4d464b70e0e4b
2023-11-16 19:35:30 +08:00
hoshi-hiyouga
ea03523c6a Update issue templates
Former-commit-id: f967abcfcd052b65745f20e2c760ca45c412b66a
2023-11-16 18:56:30 +08:00
hiyouga
caf3cbf8d7 fix web ui demo
Former-commit-id: e566a68a27872f730b111078977048755ec74a40
2023-11-16 18:41:55 +08:00
hiyouga
da411066c9 fix web ui demo
Former-commit-id: 6fead193fe44fec74c2262d8653ed2f6006fac36
2023-11-16 17:12:23 +08:00
hiyouga
95d0f77fc2 release v0.3.0
Former-commit-id: de7f5b622340ab09ebbe57ad2703e63d06dfdeea
2023-11-16 16:00:11 +08:00
hiyouga
9b2654277b update readme
Former-commit-id: 4018aabc5d1623033d27a8aced25804de79b7e7b
2023-11-16 15:58:37 +08:00
hoshi-hiyouga
f1b3bdac3f Merge #1525 from hiyouga/dev, fix #224 #336 #931 #936 #1011
Refactor llmtuner, support full-parameter RLHF

Former-commit-id: 3b92826803dc69471827b4f8204c2c3dc5310619
2023-11-16 15:47:13 +08:00
hiyouga
595fdbd95d fix css
Former-commit-id: 7afec127f60257462828298b25a5f6fd9c6f42c5
2023-11-16 15:45:38 +08:00
hiyouga
dab9385297 fix bug in web ui
Former-commit-id: a598f145ec903dd2b2c984d951b6c450b142ece5
2023-11-16 15:21:24 +08:00
hiyouga
df83def566 update ppo and demo in webui
Former-commit-id: de7571704c82121db13e3fc907379d2453100191
2023-11-16 14:55:26 +08:00
hiyouga
f9d4e37b3c fix bug in freeze tuning
Former-commit-id: f6b436a08421ca17d64abc51497f4aa43729a43b
2023-11-16 14:25:11 +08:00
hiyouga
e59a3d71e0 tiny fix
Former-commit-id: d65519d8a44b73bbb713741c23465f13c35c83f5
2023-11-16 03:27:19 +08:00
hiyouga
de3a84ac59 fix rlhf callback
Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
2023-11-16 03:26:19 +08:00
hiyouga
e017266b98 fix bug in PPO training
Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
2023-11-16 02:32:54 +08:00
hiyouga
f81a8a5e5c fix import bug
Former-commit-id: 2356029cdd120d5f7bf630b80681ce8c53bff90d
2023-11-16 02:27:03 +08:00
hiyouga
7a3a0144a5 support full-parameter PPO
Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
2023-11-16 02:08:04 +08:00
hiyouga
8263b2d32d add demo mode for web UI
Former-commit-id: 5ad34f08b4e1505d7933b973497347f126b2e818
2023-11-15 23:51:26 +08:00
hoshi-hiyouga
833cd490b8 Create CODE_OF_CONDUCT.md
Former-commit-id: 6bee64cdf9c75488033e600fb5b48738daa1ed3b
2023-11-15 20:42:15 +08:00
hiyouga
2162c37e41 update readme and constants
Former-commit-id: 7d83e3dd9101a4fdd0b589d0c1f7b609c0feecd1
2023-11-15 18:04:37 +08:00
hiyouga
b2ac8376e1 support multiple modules in freeze training #1514
Former-commit-id: 60abac70dfd778df2ae8b3a2e960ed8b607d7ab6
2023-11-15 17:08:18 +08:00
hiyouga
8079584143 fix imports
Former-commit-id: 6156f1abef631c675d150dd1cb0325cfc3820c91
2023-11-15 16:47:45 +08:00
hiyouga
09a4474e7f disentangle model from tuner and rename modules
Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
2023-11-15 16:29:09 +08:00
hiyouga
81530133ff fix #1507
Former-commit-id: 1ba9c53bd9743fa95fca1516c0ed9da352dbe9a1
2023-11-15 16:22:32 +08:00
hiyouga
cc4b384ac3 Update cal_lr.py
Former-commit-id: b92ef6c80ae108982046ec1419efb67c8b10b250
2023-11-14 21:14:42 +08:00
hiyouga
3852daf447 Update cal_lr.py
Former-commit-id: b6c3f9b24324403db41c5680a00aabc6d53bbeb9
2023-11-14 21:13:01 +08:00
hiyouga
5c97111f9d Update cal_lr.py
Former-commit-id: 1258eec806f6f4580a6eb7d9eb44f431f4c0da4f
2023-11-14 21:09:30 +08:00
hiyouga
75dd1f0f7e add cal_lr.py
Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
2023-11-14 20:58:37 +08:00
hiyouga
c9a4551012 fix #1494
Former-commit-id: 07c8d734529f03e47ef638a1bda222e8824d3d38
2023-11-14 18:07:20 +08:00
hiyouga
87197ba91d fix #1489
Former-commit-id: ebdeaca9cdfd6138c690a0fcb9f676deaddff177
2023-11-14 15:27:05 +08:00
hiyouga
7461bf84e5 support eval remote dataset
Former-commit-id: 71dd2698bf8c0b9ef7af995fb1e49e39fa66074e
2023-11-14 02:42:30 +08:00
hiyouga
fbc0357b2e fix dc link
Former-commit-id: 04c3a1f1c98d8f191102e359def0c8dcdc9621e3
2023-11-13 23:22:56 +08:00
hiyouga
ec334f5891 release v0.2.2, fix #1478 #1466
Former-commit-id: c9534c411716e1dceb54c5eb35fe845c93ee2973
2023-11-13 23:09:05 +08:00
hiyouga
885efe772e fix #424
Former-commit-id: ca24d445f825e120e659f5cd080a954c2243b8f2
2023-11-13 22:42:23 +08:00
hiyouga
64fc9ba678 refactor evaluation, upgrade trl to 074
Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
2023-11-13 22:20:35 +08:00
hiyouga
989eccd286 fix flashattn warning
Former-commit-id: 6eb095d39bd82fdbdb729a0ea57fc7246e3a60d6
2023-11-10 18:34:54 +08:00
hiyouga
f0766a2ab0 add todo
Former-commit-id: 0bd884feb11736d0ab24ca19885151cb47d9dcd3
2023-11-10 14:38:18 +08:00
hiyouga
178b85ff9a refactor constants
Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
2023-11-10 14:16:10 +08:00
hiyouga
68dd1ef121 tiny fix
Former-commit-id: 97ba2027bb1ddc01a3c824c40d5a180828810c2c
2023-11-09 17:20:49 +08:00
hoshi-hiyouga
b222cffe98 Merge pull request #1454 from yyq/main
Update finetuning_args.py

Former-commit-id: e67d8b93705383a8590f99e26e9fe8f663712aef
2023-11-09 17:12:18 +08:00
Yanqing
b4f1ab93d1 Update finetuning_args.py
更新 chatglm/falcon/bloom 的 lora_target 的名称

Former-commit-id: 06606739af035a80ae9ddba9d12c965ed289305d
2023-11-09 17:04:40 +08:00
hiyouga
f2e139f5cd fix #1452
Former-commit-id: 4d16214467715df458e24d03bb7d303d62b8bdcd
2023-11-09 16:41:32 +08:00
hiyouga
a9cbca1604 update readme
Former-commit-id: f7ead54042868550a3e8a6928ea3c0e2673f15b3
2023-11-09 16:00:24 +08:00
hiyouga
3a30ce6c16 release v0.2.1
Former-commit-id: 1c30f2be0140f5ab47c2bc811170d0271a0cdad6
2023-11-09 15:54:16 +08:00
hiyouga
48ec5355f9 add template, modify datasets
Former-commit-id: 81e54beb4d0f792f4fd7f450643caaf10f2f0b7d
2023-11-09 15:53:23 +08:00
hoshi-hiyouga
11859bc322 Merge pull request #1436 from lvzii/main
fix tokenizer config changed after pretrain

Former-commit-id: f485c3983e413fd3a3a57b451800705b072869a7
2023-11-09 14:30:50 +08:00
hiyouga
28c67a5be8 support parquet format #1446
Former-commit-id: 44a3b9ac9f10d2012b8ad3d8c48123db9a0da2f1
2023-11-09 14:17:40 +08:00
hiyouga
44fe93e9b0 fix #1438 #1439
Former-commit-id: 84260d58dda22adc32c26bc943ed2a36fd01341d
2023-11-09 13:45:10 +08:00
lvzi
09a1681b63 fix tokenizer config changed after pretrain
Changing tokenizer's attribute at preprocessing stage will result in saving a wrong tokenizer.
for example, baichuan2

Former-commit-id: 19942b5314b84267691f0a5657d0679f2ddbe58b
2023-11-08 15:50:46 +08:00
hiyouga
f5ba2190fb fix ppo train and dpo eval
Former-commit-id: ced863031836632cb5920e22ae6991f251372118
2023-11-07 22:48:51 +08:00
hiyouga
14a38b5069 fix #1422
Former-commit-id: 25d7bbd0a5142f001bd2ff498df07b24137050a9
2023-11-07 19:42:01 +08:00
hiyouga
f23e5b602a fix reward model loading
Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
2023-11-07 17:20:51 +08:00
hiyouga
857696ed9c fix args
Former-commit-id: 44d0fa2ac6a6423c7ddaf91eb8998c1b9248c04e
2023-11-07 16:36:06 +08:00
hiyouga
2084133058 update info
Former-commit-id: 89643b8ac1e3fa8d2f29f1c88e4d4503410c0d05
2023-11-07 16:28:21 +08:00
hiyouga
f7f0c3070e delete file
Former-commit-id: 7d6355db0fd5809b99f3fa42753cf4dffd251fd1
2023-11-07 16:20:12 +08:00
hiyouga
46235aa514 fix #1418
Former-commit-id: 9bfecc72c53cf95fea4a9ff02ec40a65da6d4f54
2023-11-07 16:17:22 +08:00
hiyouga
2eb65d21ac upgrade peft, fix #1088 #1411
Former-commit-id: aa7d104f8e050d12cb8f585bc8a52c850995500f
2023-11-07 16:13:36 +08:00
hiyouga
37a0d62a82 update requirements
Former-commit-id: 82ebbbbb80b3f3f616274210970738d0f44b5a0a
2023-11-06 19:01:21 +08:00
hiyouga
21ac46e439 use seed in evaluate.py
Former-commit-id: ab5cac1dfa681933f3266827f80068ce798b4c56
2023-11-06 18:17:51 +08:00
hiyouga
ba3e8ba20c update readme (list in alphabetical order)
Former-commit-id: e6a67b5477ee095bd92764581cfe6af57e799a69
2023-11-06 17:18:12 +08:00
hiyouga
2c48e798ca update templates
Former-commit-id: 85be2e242b062283f192c4c4d0715dc1e8a68589
2023-11-06 12:25:47 +08:00
hiyouga
4e40f5b62b fix #1383
Former-commit-id: 9b8a782aa80f27c3e2a2e2621f9be17cae1a27e8
2023-11-06 11:42:23 +08:00
hiyouga
2a8892b785 fix deepseek template
Former-commit-id: 1fdbcdad9a1cdb20299350efd87a8e5cb8c625a3
2023-11-05 13:08:46 +08:00
hiyouga
ee3b33ff03 support deepseek coder #1378
Former-commit-id: ae0c829917b9de10e71199c85c77a52cdcd2b7b3
2023-11-05 12:51:03 +08:00
hiyouga
b2c3001f8e fix #1365
Former-commit-id: 0277d120e62164bb7fa1d6043b8fcc52c881fe96
2023-11-05 12:21:07 +08:00
hiyouga
6cfe1e1ac2 tiny fix
Former-commit-id: 594c510a20d6c2782d7b7ffff18931e3003e6c22
2023-11-03 01:26:06 +08:00
hiyouga
52326870e4 fix #1290
Former-commit-id: ad911d258c4cea16f54d09bc192e076c21d26394
2023-11-03 00:44:53 +08:00
hiyouga
217fde0918 fix bug in data loader, support dpo eval
Former-commit-id: f4f3dcff990468a2fa864b7176adcebbcf16dac9
2023-11-03 00:34:26 +08:00
hiyouga
065021d82a update data readme
Former-commit-id: 6a65ef44ed58714c611da60b5af96b85352e8735
2023-11-03 00:15:23 +08:00
hiyouga
4bb643e685 update data readme (zh)
Former-commit-id: b32fb3a984c681732b82f6544d6c05a98c34cf4c
2023-11-02 23:42:49 +08:00
hiyouga
b77c745b1a support sharegpt format, add datasets
Former-commit-id: 202daf8987ccb7523be03ca535b572b5c9e65994
2023-11-02 23:10:04 +08:00
hiyouga
7d13501b94 support pagination in webui preview
Former-commit-id: f2307e26b9c2ce5d60917cce5a9638466ea676c8
2023-11-02 21:21:45 +08:00
hiyouga
ac74639b32 fix webui
Former-commit-id: 9192948fa221c0275ddfa579ef6b3442d45b8962
2023-11-02 18:03:14 +08:00
hiyouga
12fa56ae68 support warning in webui
Former-commit-id: 9903b523fad2f0ec0e66c3d313823bd4674bfa2b
2023-11-02 17:57:04 +08:00
hiyouga
f11b863f4b fix #1349
Former-commit-id: 556c023eab2a68560b26a7d5318a79410fb0c700
2023-11-02 17:02:44 +08:00
hiyouga
f3e4b72957 fix #1356
Former-commit-id: d2ed436108a339d405dad1be1ca15baca3d6d3e4
2023-11-02 16:51:52 +08:00
hiyouga
8d52fb46ca fix #1325
Former-commit-id: 59f2cbbd52d4646fbd1ba83032bf522ecc49a50f
2023-11-01 23:38:49 +08:00
hiyouga
dab8f45033 fix chat
Former-commit-id: 68f2b3df09c4c8638b9e225fd5b8aed3541e97a0
2023-11-01 23:07:58 +08:00
hiyouga
bff8b02543 update gradio, support multiple resp in api
Former-commit-id: a34263e7c0e07a080276d164cdab9f12f1d767d2
2023-11-01 23:02:16 +08:00
hiyouga
2406200914 fix SFT trainer
Former-commit-id: bf09b6a6cd75cc2738d9af6b8c30bcbba77fa9b5
2023-10-31 21:52:52 +08:00
hiyouga
db06fcfc84 fix #1316
Former-commit-id: 88a753fe80e277007bac2264aee24024e18f2314
2023-10-31 11:32:08 +08:00
hiyouga
93b9f74e9f update projects
Former-commit-id: 33d58e9171ad2693b9d54715eb61a6f4326c59f4
2023-10-29 22:53:47 +08:00
hiyouga
33ec844f76 add projects
Former-commit-id: 495a68cd5962dd3b3af7e4a920d91ac25531a862
2023-10-29 22:07:13 +08:00
hiyouga
0f727b393e update constants
Former-commit-id: ebacbb1072045924a7e335cc9dda488d6f0be8b3
2023-10-29 13:30:20 +08:00
hiyouga
7da2aad6ee fix vicuna template
Former-commit-id: a98eda0803e4b73a24f12d848e14161451921e98
2023-10-27 22:15:25 +08:00
hiyouga
6f09f50d02 fix chatglm3 template
Former-commit-id: 69bcbc9f6c98e4f4ad97ec0306b33ab21923d311
2023-10-27 21:12:06 +08:00
hiyouga
5919832059 update readme
Former-commit-id: 6fb92c7088316c56ce8656e540fc47b0a5a1bf18
2023-10-27 19:19:03 +08:00
hiyouga
f7635c1afc support chatglm3
Former-commit-id: ba82e13bbeed3b262d301196b1860d73f319401d
2023-10-27 19:16:28 +08:00
hiyouga
c762168ed0 support dataset cache
Former-commit-id: f79ee62eb4a2a4a01cb4e2a6aa2d07158cf8eb59
2023-10-26 21:48:45 +08:00
hiyouga
67a46e553f fix #1287
Former-commit-id: d885aca472c6448bbf9a9e8d16bead92038825e3
2023-10-26 17:49:41 +08:00
hiyouga
e406f37b54 fix #1285
Former-commit-id: 2f8fe4439506e844b147fe38b5eb878c5748c31c
2023-10-26 16:34:52 +08:00
hiyouga
62fe877124 remove filter in preprocess
Former-commit-id: 9eac08b35fec47129a29c401ca265343f8388ab0
2023-10-23 23:46:02 +08:00
hiyouga
a0e682ba79 update neftune logic
Former-commit-id: bb4f0589ed23bf0236d3e918272ad64f0a05ef39
2023-10-22 17:42:13 +08:00
hiyouga
49e8a87383 fix webui
Former-commit-id: a5a5a7bc1f53d36e1b26e418999465903cb7d9ed
2023-10-22 17:24:56 +08:00
hiyouga
b2764b49ca add new options in webui
Former-commit-id: 6698b832dd9cc2d7d60be4fa5ab90e34a7e9d8e0
2023-10-22 17:17:58 +08:00
hiyouga
06b810de8f fix recursion error
Former-commit-id: c7938188c36a71a878bca982b7dd151195164986
2023-10-22 16:28:37 +08:00
hiyouga
6da51565f5 reimplement neftune
Former-commit-id: efe9e5a194d3a9f052701d904715238816e4c09e
2023-10-22 16:15:08 +08:00
hoshi-hiyouga
1f69965239 Merge pull request #1252 from anvie/neftune
add NEFTune optimization

Former-commit-id: 85d5c5fbe731f486c3e83812227fa05edc131487
2023-10-22 15:59:20 +08:00
anvie
af2d61178d add NEFTune optimization
Former-commit-id: 603e0298af64116ac07130fe6661a9ba823c186c
2023-10-21 13:24:10 +07:00
hiyouga
6a955ccf4f fix openchat template
Former-commit-id: 88b9b657bc50495ac4c42f64195fc652fe4ca3df
2023-10-21 01:25:42 +08:00
hiyouga
c0658711ca fix tokenizer padding side in evaluate.py
Former-commit-id: bcb43ff8ba1946c1f7e7865c9d0fb47ba276935d
2023-10-21 00:30:04 +08:00
hiyouga
d602f06882 fix #1232
Former-commit-id: 49975755d47344e362145c52548fdda8783f2c0c
2023-10-20 23:28:52 +08:00
hiyouga
1cb9a38ac2 fix #1215
Former-commit-id: d91b43a8afbea4859357f2224e3d9b9d71160e6d
2023-10-19 16:19:21 +08:00
hiyouga
47a1f73d0f fix #1218
Former-commit-id: b301f35bd4a3bf368159c8f5fb4e2736f922115b
2023-10-19 16:17:41 +08:00
hiyouga
142dd63b47 fix #1228
Former-commit-id: e4e0cae3f55da2f1b566c97dbfdd7fc5b7b728a4
2023-10-19 15:54:10 +08:00
hiyouga
b1bd8370c2 fix #1217
Former-commit-id: 065fc0a6f3f005bb87e1c5c126c8b6bb470ce700
2023-10-19 15:52:24 +08:00
hiyouga
215660c8da rename webui
Former-commit-id: 26feaf80fff6177d9eb4e28ad18feb6d34d3ea27
2023-10-16 15:16:24 +08:00
hiyouga
0cafe67efe fix #1197
Former-commit-id: 00100e23fcfef9587fda4cf01c62599d996e1176
2023-10-16 15:13:46 +08:00
hoshi-hiyouga
ea83b3222b Update README_zh.md
Former-commit-id: 3450404bb9a33c3bd4b45ac4afcf51062f8c7d1d
2023-10-16 00:28:27 +08:00
hoshi-hiyouga
725087a04f Update README.md
Former-commit-id: d84896597eded79f78224faed81cc9f2df222978
2023-10-16 00:23:37 +08:00
hiyouga
d627ab4855 release v0.2.0
Former-commit-id: 7f941c1ab6c52915aa2675fa77cae5efc530fdd9
2023-10-15 20:49:43 +08:00
hiyouga
7d867e8df4 update readme
Former-commit-id: a99a92b129a3d2372e66ca73b87c3e521f144043
2023-10-15 20:28:14 +08:00
hoshi-hiyouga
3d34d44497 Update README.md
Former-commit-id: e6fcc1831dadd2ec2c0acb14697a35f6471139ab
2023-10-15 20:23:22 +08:00
hiyouga
a6f800b741 fix config, #1191
Former-commit-id: 5dbc9b355e85b203cb43ff72589374f0e04be391
2023-10-15 18:28:45 +08:00
hiyouga
a003d1fa1e disable tqdm in webui mode
Former-commit-id: 832be571bec2eefb79ea88f110b7827f5c1249e6
2023-10-15 16:18:25 +08:00
hiyouga
c2e84d4558 refactor export, fix #1190
Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
2023-10-15 16:01:48 +08:00
hiyouga
68330eab2a fix eval resuming in webui
Former-commit-id: b28b53cd06777f213ef7b925a914ff5fd357ade1
2023-10-15 15:45:38 +08:00
hiyouga
7070f3969d tiny fix
Former-commit-id: 47b7b34357708a5354d542ddc239146c6417d718
2023-10-15 05:02:48 +08:00
hiyouga
e4727ab155 fix callback
Former-commit-id: 51208655a8c1d66551b7b644247321a3583debdc
2023-10-15 04:59:44 +08:00
hoshi-hiyouga
280e7d97ad Merge pull request #1186 from hiyouga/dev
Support Web UI resuming training

Former-commit-id: fcbecd0c4cb17b883e9b780a71d2abc38228293e
2023-10-15 04:53:14 +08:00
hiyouga
31e3805fb8 implement webui resuming training
Former-commit-id: 2d41672ef52414c56c50c8b4fdc442797ba682e9
2023-10-15 04:52:19 +08:00
hiyouga
ef248dbe15 fix bugs in webui
Former-commit-id: 4befa74ea630d90e4d7a1f7d7c34d39257717ec1
2023-10-15 03:41:58 +08:00
hiyouga
6a61b4b638 refactor webui
Former-commit-id: 813ecd8e51949c21ab6fbaa51cc2b1a84ee07952
2023-10-15 03:06:21 +08:00
hiyouga
4b1473502f fix loading dtype
Former-commit-id: d54a356128f7e335c12089702cf3de7f5b4baf16
2023-10-14 20:15:24 +08:00
hiyouga
bf211d818d fix #1176 #1177
Former-commit-id: 5627a2b57c270a78095a32083e2dc7aa02162875
2023-10-14 20:00:17 +08:00
hiyouga
27dd87c890 fix #1184
Former-commit-id: 5b069a967823e659dbc70b0d50361b3ad248087e
2023-10-14 19:20:11 +08:00
hiyouga
8659084ab0 fix webui
Former-commit-id: a0fe43aac968d9f6ca4724b8d718b45c03063b91
2023-10-13 16:27:59 +08:00
hiyouga
e1c9dcea93 update readme
Former-commit-id: 9d9018fad314cdc4512b4847633489cdd7a25347
2023-10-13 13:53:43 +08:00
hiyouga
171339ab17 update discord link
Former-commit-id: f725cb4940a3a18e9f1edca986ef06d425b39710
2023-10-12 21:44:28 +08:00
hiyouga
8542ba5c69 rename repository
Former-commit-id: 6100ac080a5e52edd66b98147aede6cb77481beb
2023-10-12 21:42:29 +08:00
hiyouga
97b74d328b fix ppo args
Former-commit-id: 0f12899951808f53a482082eb116bda309775930
2023-10-11 23:40:50 +08:00
hiyouga
3198a7e5f4 refactor model_dtype, fix PPO trainer
Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
2023-10-11 23:16:01 +08:00
hiyouga
a2d08ce961 add averaging in evaluation
Former-commit-id: b39d6e0b8658e1c69bbaf6bcb6cfaa8f7af30110
2023-10-10 23:16:31 +08:00
hiyouga
bd8ea09479 fix aquila template, repair sft packing mechanism
Former-commit-id: 8c82cfa5dd4bec957426b5bf176d242c77552ab0
2023-10-10 18:49:55 +08:00
hiyouga
6d0d46c7fb tiny fix
Former-commit-id: 31ccd3329ac634b239c43d60bd955cd95670df16
2023-10-10 17:41:13 +08:00
hiyouga
820540780a update readme
Former-commit-id: 4a9c8a4f18b07455c34e6c1e6bbc81cbefd82eea
2023-10-09 20:02:50 +08:00
hiyouga
f74d600497 fix flash shift short attention
Former-commit-id: e44ad23eafa39b3ac0400b6f97cd440106a87f44
2023-10-09 17:54:48 +08:00
hiyouga
94fec9f50e fix webui args
Former-commit-id: 64aa75c8cd7c84ab4a0f1dbaf4763765ba973f54
2023-10-09 17:13:57 +08:00
hiyouga
e387a50475 fix shift short attention
Former-commit-id: 9a49cce8e6f6b222f74a07bdab40efee6a77b0f1
2023-10-09 17:07:46 +08:00
hiyouga
5c4248a29c update webui #1086
Former-commit-id: 65a48bc398f18f71f5f2659b2070e3b9593af243
2023-10-09 14:50:14 +08:00
hiyouga
f22886e2b6 fix #1097
Former-commit-id: c5b8796322d9d48e815038f9fecf0ce39036a4ee
2023-10-08 22:29:26 +08:00
hiyouga
33af3cbf37 add llamafy_qwen.py
Former-commit-id: 6cdc91543c022edcc98076488f06e809fde9bad7
2023-10-08 22:05:36 +08:00
hiyouga
728dfb1be7 fix #1068 #1074
Former-commit-id: 26c6bfd21de06cc56be9a58e2ef69045ea70cc14
2023-09-28 14:39:16 +08:00
hiyouga
e49f7f1afe fix bug in packed sft dataset
Former-commit-id: 51d26b2af6612e65a91c576da5270028da27b322
2023-09-28 01:16:46 +08:00
hiyouga
21a454fa6c tiny fix
Former-commit-id: 35b355b76d2a8f8adf3750a905224e52d03d218f
2023-09-28 01:03:04 +08:00
hiyouga
22c6c27f78 tiny fix
Former-commit-id: 7451b2ae7e58d0f1857f01a037672a8c53b1bd0d
2023-09-28 01:02:11 +08:00
hiyouga
aecbb43096 fix #1064
Former-commit-id: fd4660aa72d981d7efdad465f24a59358626c975
2023-09-28 00:53:29 +08:00
hiyouga
fa53fd2db2 fix bug in pretraining
Former-commit-id: 18a2d90bd6e7c3e1e3513e6f9d895e4048b35b04
2023-09-28 00:45:20 +08:00
hiyouga
1c150995ae fix layer norm dtype
Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
2023-09-28 00:25:55 +08:00
hiyouga
6c5d8f089e fix #1026
Former-commit-id: d0940d0dbd03d4bbcc955304566b0d5507edf9e6
2023-09-27 22:57:09 +08:00
hiyouga
dd623325e8 fix #424
Former-commit-id: daaf89f1126112a73b9f115b0f5617a8cd974a3e
2023-09-27 22:49:43 +08:00
hiyouga
e8a375c8f2 fix #1032
Former-commit-id: 1235b2da5a79ffefd1342054ea8e7dabf47398c1
2023-09-27 22:42:16 +08:00
hiyouga
386d85ae72 refactor finetuning Args
Former-commit-id: be425a70a4c8f051717cf1e4464dbd79dae4c0b5
2023-09-27 22:28:06 +08:00
hiyouga
ebb3901b05 update readme
Former-commit-id: badbc210435d92cea8799bcd1af4c738da902cd7
2023-09-27 21:57:47 +08:00
hiyouga
20130b486c support LongLoRA
Former-commit-id: 0832ed37e7947d699f17375648a52f80752c2b6b
2023-09-27 21:55:50 +08:00
hiyouga
73c48d0463 add CMMLU, update eval script
Former-commit-id: 47f31f06a946eefa5a972e4a566cf3ce05e1e111
2023-09-23 21:10:17 +08:00
hiyouga
f7cecd20e3 update evaluate
Former-commit-id: 288137a76ed1528faa39b467da22f6468ba368ee
2023-09-23 11:55:31 +08:00
hiyouga
2bc64a7636 move file
Former-commit-id: 8711ca9b5421f971ee4cb2fada23832f1021577c
2023-09-23 11:52:12 +08:00
hiyouga
9564ddbb48 shuffle few shot examples
Former-commit-id: 2c9c14c122382e640dfa41a3799628c764c99457
2023-09-23 00:53:20 +08:00
hiyouga
28062c71b5 fix MMLU
Former-commit-id: eeab92323899694010469451b8dfb1f00d685bff
2023-09-23 00:42:23 +08:00
hiyouga
35d1921081 add MMLU and C-Eval script
Former-commit-id: 3403f876127b4b99c5e3edb2834cc3b9a3a0063f
2023-09-23 00:34:17 +08:00
hiyouga
4fbdf18c70 fix #1000
Former-commit-id: 85de2d0a99e4a81fae890a963ccbb5c6142d52d4
2023-09-22 15:00:48 +08:00
hiyouga
5e07ab01f0 update readme
Former-commit-id: 776f9ea3a5837cb3f80ebe53f19e9951400bf05d
2023-09-22 14:34:13 +08:00
hiyouga
fac465a21e fix webui
Former-commit-id: e28485b476816c1bd6c34f7ff9efaa9e3fb85176
2023-09-21 19:55:38 +08:00
hiyouga
e145a2ce0c tiny fix
Former-commit-id: d24ea58c1a44b94227f4cb60f13fc1dd79997d01
2023-09-21 19:52:06 +08:00
hiyouga
dc68c313ee fix #944
Former-commit-id: 032245647848aaa4167086636b6c985268c5fee3
2023-09-21 19:51:02 +08:00
hiyouga
95c0d9ab24 tiny fix
Former-commit-id: 1a7ddd8c1d20dc251f53923bd0ab9f3f1031dd21
2023-09-21 15:25:29 +08:00
hoshi-hiyouga
46a718f339 Merge pull request #975 from statelesshz/npu-support
Add Ascend NPU support

Former-commit-id: b348c7569c0d3f46b03fb274226444ac7a80e68d
2023-09-20 14:56:50 +08:00
statelesshz
496ba46960 support export model on Ascend NPU
Former-commit-id: 50f94e6d9d62c848db7a3db85fa999d67ddd9f04
2023-09-20 10:26:02 +08:00
hiyouga
43ae0aca1d fix webui
Former-commit-id: 2aa06a5a74d98ec25ed6e1e39df11230670f5bad
2023-09-19 18:35:21 +08:00
hiyouga
b8574c1b82 fix error info
Former-commit-id: b90ed220c5e94086d2b73045eff2440ff1b58c5c
2023-09-19 18:30:23 +08:00
hiyouga
32f8b1082b add tests.cal_flops.py
Former-commit-id: 47a119db6c6e937f6ed96f70e3cda6031b9fbd0d
2023-09-16 23:40:41 +08:00
hiyouga
6443fef31a update readme
Former-commit-id: 813c2df5dc179d82c6c999f63c2640e7c3f6aaff
2023-09-16 17:33:01 +08:00
hiyouga
14c3795a7d fix #913
Former-commit-id: d67c11d69277292648dd9889a7321345e2c0c437
2023-09-15 20:58:28 +08:00
hiyouga
3d9e2de573 fix #896
Former-commit-id: 4b70d623d817460de4732749110622e4a1b51958
2023-09-14 18:37:34 +08:00
hiyouga
0ca36a0f8d fix #887
Former-commit-id: e131bc03e05ccae3c6ad8bb42ccf2cdcc2cf3cea
2023-09-14 17:56:58 +08:00
mmbwf
3e5555502a Update utils.py
Fix parameters load error.

Former-commit-id: 112850364c7fdb53e3a38d42861404fc519108ce
2023-09-14 15:38:04 +08:00
hiyouga
fbf5b5e0a9 add MathInstruct dataset
Former-commit-id: 3d1d4b47055739854cf9788a902607e1bbba3723
2023-09-13 22:30:14 +08:00
hiyouga
3305e66f8c fix ppo save model
Former-commit-id: 300ca6d904524f46cb520056e1319a1e9a13d169
2023-09-12 16:25:29 +08:00
hiyouga
e19a44c12b fix #762 #814
Former-commit-id: 9a30ee5009040afbc524dbac0dad99904b2adf5f
2023-09-12 16:10:10 +08:00
hiyouga
8b0e6b9d1b tiny fix
Former-commit-id: d8ea0691f84c971e6860526714fc9873c350b064
2023-09-11 18:27:08 +08:00
hiyouga
f3e638ac6a Release v0.1.8
Former-commit-id: d9666411375964d334d0a93ec162b27e05f70d49
2023-09-11 17:31:34 +08:00
hiyouga
42e0b30476 update flashattn, fix ppo save model
Former-commit-id: 0b08bc3dac246d4aa3f89afb7172529dcad9c39f
2023-09-11 17:25:36 +08:00
hiyouga
a09a7b650d remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
2023-09-10 22:23:23 +08:00
hiyouga
332d7bbd56 truncate readme
Former-commit-id: fed5d0cc87e4a5a023f2edae622f2820bded1509
2023-09-10 21:04:20 +08:00
hiyouga
d3b6fece71 update readme
Former-commit-id: c42fe77fec2918fe8811d48ec88e9a7c1e6f07ab
2023-09-10 21:01:20 +08:00
hiyouga
9d963b82de update readme
Former-commit-id: b4109cfe548e091cd20fa84815dce5ff3974a090
2023-09-10 20:52:21 +08:00
hiyouga
a402161631 support FlashAttention2
Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890
2023-09-10 20:43:56 +08:00
hiyouga
b481ad58e6 fix #850
Former-commit-id: e5975c4c6b8bd47ec506b0d4a4703bee05495436
2023-09-10 14:22:03 +08:00
hiyouga
f91c5f2638 fix lora target
Former-commit-id: d822e41e7ac7e310ee49e347fc45754284ce30b8
2023-09-09 17:04:45 +08:00
hiyouga
7143c551ab support lora target auto find
Former-commit-id: bce9984733d88bf013847eed523d1c75fdf0995e
2023-09-09 15:38:37 +08:00
hiyouga
50e93392dd fix chatglm2 tokenizer
Former-commit-id: 1ab60b4a93fa1be5dfe6ffbd4deb64c0f9d9b431
2023-09-09 13:50:29 +08:00
hiyouga
9f83e93839 add baichuan2 convert script
Former-commit-id: 4d676e0ea9e59c1be13ecb47734917ba78938ac8
2023-09-08 22:59:41 +08:00
hiyouga
692b132dbf fix bug in DPO data collator
Former-commit-id: 4fc262cdf1347691e253bdfbd96568db5a49c086
2023-09-08 20:45:07 +08:00
hiyouga
e70b3e8947 fix #761
Former-commit-id: be76f6cbe5143f781b6b39603b80392253b3080a
2023-09-08 20:22:18 +08:00
hiyouga
612d97db6f change to right-padding, update reward score #803
Former-commit-id: baa90415bc8f5ebd423d001378b51c3a3a6c2ec7
2023-09-08 20:04:31 +08:00
hiyouga
bb1b67c076 fix chatglm template
Former-commit-id: 69a824628b4d6a56a680a7e713b217877c6c15c5
2023-09-08 14:45:58 +08:00
hiyouga
5a75c31caa update requirements
Former-commit-id: d796a4a5709c390629bafbeb7c91fccf6a9076d0
2023-09-07 19:26:25 +08:00
hiyouga
8b9210286b fix #818
Former-commit-id: e81fd458c279ed2f3cee780e517482b425c8886d
2023-09-07 19:19:53 +08:00
hiyouga
b5acec34f7 add deepspeed check in PPO training
Former-commit-id: e203ec7f71f504ccbaa89c27d20b8a0d9fa53f7e
2023-09-07 19:12:40 +08:00
hiyouga
86d835878c fix #809
Former-commit-id: 2783ca75365d7c373cefba039788a48f0b8f35fc
2023-09-07 19:04:32 +08:00
hiyouga
eae7b331d3 fix baichuan templates
Former-commit-id: f48a49e835b32f3991cfad8874c7b9c78953809f
2023-09-07 18:54:14 +08:00
hiyouga
ed89e29bcc update baichuan2 template
Former-commit-id: 16d9f8ba176443c5b397233da621600d6e1e1eec
2023-09-06 21:43:06 +08:00
hiyouga
c2b1886aff add Baichuan2 models
Former-commit-id: 90b3f02c44c0b8cc1b59f37af3a1ec28874a8a61
2023-09-06 18:40:11 +08:00
hiyouga
218f36bca5 add Baichuan2 models
Former-commit-id: 36960025e9274b574f57e7a7bf453cd96956e922
2023-09-06 18:36:04 +08:00
hoshi-hiyouga
b91fc1f5b3 Merge pull request #786 from kinghuin/patch-1
fix utils.py bug

Former-commit-id: 26aad616340748e1594a60119ca9434908bf7465
2023-09-05 10:49:34 +08:00
Q
2a22bf9c15 fix utils.py bug
Former-commit-id: dc490117d50c3cbc070b804bac89400f4290272f
2023-09-05 10:38:01 +08:00
hiyouga
62e2037125 fix #763
Former-commit-id: e424b928a35097b783af879a2290f59b2158801d
2023-09-01 23:13:05 +08:00
hiyouga
e5b72c6a77 refactor dataset_attr, add eos in pt, fix #757
Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
2023-09-01 19:00:45 +08:00
codingma
93be211f80 Merge pull request #741 from hiyouga/feature-addDatasetCheck
Feature add dataset check

Former-commit-id: 4b6dabe73d2c7edc94cd495390577c8bcf88428b
2023-08-31 20:57:36 +08:00
codemayq
9ae3fb4ced update llama2 template
Former-commit-id: 01de1d51d9fa5a22a338b6ed18ffad4d0ad5e3e8
2023-08-30 16:23:56 +08:00
codemayq
f641075789 add dataset stage check
Former-commit-id: 5c719a7ce988339d034a653456da9742dc2cec7c
2023-08-30 16:23:08 +08:00
codingma
f7658db1b6 Merge pull request #651 from hiyouga/feature-dataset_stage
add dataset stage

Former-commit-id: 3b0ef57405cbc22ff8ce4eef2cfcb73872519db5
2023-08-28 16:03:45 +08:00
codemayq
b869bc1a20 add ad gen dataset
Former-commit-id: fcd0788aa4dda0cecc1420d369d371032a207810
2023-08-27 20:35:32 +08:00
codemayq
a72d756d77 add text format dataset preview in webui
Former-commit-id: cd30871aadb40cd3d598a6d0b415946744d2d550
2023-08-24 19:45:36 +08:00
codemayq
d3fd8f89b8 add stage in DatasetAttr
Former-commit-id: 9c55200d8de0623640f529dbf39b8b0f169636d3
2023-08-23 20:54:53 +08:00
hiyouga
180a05a446 fix import error
Former-commit-id: b3207a974a45038591b8cbbcf20d1ca1142d6679
2023-08-23 20:45:03 +08:00
hiyouga
eb9ac9ee1f fix #649
Former-commit-id: e6120a937ddb4f3c0b9bcb2466742f5cf4f77f8c
2023-08-23 20:21:15 +08:00
codemayq
a6662b73f5 add readme for dataset
Former-commit-id: bdcb0ea40e726e4c5752f938b379ed9a18e7e1d0
2023-08-23 19:55:45 +08:00
codemayq
cbc7db3478 add dataset stage and filter dataset when stage chosen in webui
Former-commit-id: 26e4136449a4df6028d834fd16a0f4a7c532759d
2023-08-23 18:54:23 +08:00
hiyouga
4606340f0f fix webui
Former-commit-id: 95304b6822d9fe04bcddc1ee246a56389bd5f96a
2023-08-23 11:03:35 +08:00
hoshi-hiyouga
d4b4ccd597 Merge pull request #644 from hiyouga/fix-quantization_bit
fix quantization bit is ""

Former-commit-id: e1a8eca182e532b48e472919b4474656a726b40c
2023-08-23 10:45:45 +08:00
codemayq
9c3f4e3a37 fix quantization bit is ""
Former-commit-id: 0dcab66f8843e2887f9f7ca66334122fef35c5b7
2023-08-23 10:08:17 +08:00
codemayq
440e00d8f9 fix quantization is ""
Former-commit-id: 2469cc16d1dd3f5ee822edc18b2d7021ff7cba03
2023-08-23 10:04:03 +08:00
hiyouga
6310613699 update template
Former-commit-id: a95f3a4d62de1073a78125401cf4289ec0523156
2023-08-22 19:46:09 +08:00
hoshi-hiyouga
f55907dbea Merge pull request #629 from panpan0000/main
add rm dataset explanation

Former-commit-id: c2b4571d0ffb6298d6e07212982d9c13efd65adf
2023-08-22 13:41:44 +08:00
Peter Pan
5cac87d317 add rm dataset explanation
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>

Former-commit-id: 1efb95025be6501f1b30b20e7c711d3590b5d1ee
2023-08-22 01:33:59 -04:00
hoshi-hiyouga
9c0622de13 Merge pull request #619 from hiyouga/feature-templateTest
add template encode test

Former-commit-id: 8a1587ae49fff3968e0182f4fcc9a65dfdb260fc
2023-08-21 20:56:34 +08:00
codemayq
37b93c8b71 add template encode test
Former-commit-id: c15e0d6847cbc055d8376b3c43ac4fbd17b5877a
2023-08-21 20:51:24 +08:00
hiyouga
d6be98cda6 fix #617
Former-commit-id: a7bdaf1c92c7d798caf8438dc42a8972632ec584
2023-08-21 18:16:11 +08:00
hiyouga
4d128acc17 fix #608
Former-commit-id: c02a6809124fcfd06628c49c95d419ec2d8cc8ef
2023-08-21 17:49:36 +08:00
hiyouga
516df9ecce fix baichuan template for training #597 #616
Former-commit-id: 6530c1d972301eac9ef058b3235618bb09833f15
2023-08-21 17:41:51 +08:00
hiyouga
8eec1d50e1 fix #595
Former-commit-id: a360ccf9aa0484ce783eaa5857cf698b3ac2051e
2023-08-20 16:40:00 +08:00
hoshi-hiyouga
cfb096d43a Merge pull request #596 from beat4ocean/beat
fix KeyError: 'lang' bug

Former-commit-id: dd22541cdf1b832d20bb894d78c034afce841bfb
2023-08-20 16:37:40 +08:00
beat4ocean
713fa28804 fix KeyError: 'lang' bug
Former-commit-id: 4d4d9172b1f362cb4876315f1f5739e417055065
2023-08-20 15:32:36 +08:00
hiyouga
5549f35939 fix ppo trainer #551
Former-commit-id: 050a5447c191b8c50a0826a0f03bae499bff8b48
2023-08-20 14:07:11 +08:00
hiyouga
6eed1db36c Release v0.1.7
Former-commit-id: 81abe8d6cabaa1ebe74dc32a5dc143389e4c9f31
2023-08-18 17:21:27 +08:00
hiyouga
948124f55e tiny fix
Former-commit-id: 0ee159654ac6339c162745b004e2152ba6fe3c81
2023-08-18 13:07:35 +08:00
hiyouga
2b191ca776 support ppo score norm (trl 0.5.1.dev required)
Former-commit-id: 2b25db6d260ec1532281a592e873579346c7d21c
2023-08-18 12:02:42 +08:00
hiyouga
be4d2822ea fix PPO trainer #551 , update readme
Former-commit-id: faead74849470cebae9e37cde5fab2a71b32aa43
2023-08-18 11:43:10 +08:00
hiyouga
736ddd0319 update readme
Former-commit-id: beaf2fb737dbe64d35334d88b42935c89ef09eee
2023-08-18 01:51:55 +08:00
hiyouga
dfa289aa72 Update .gitignore
Former-commit-id: a1772a4dfef8dfaf7c2c321fad0a70ccf95fe6a0
2023-08-18 01:43:42 +08:00
hiyouga
c2644f939a update training resuming
Former-commit-id: 2ec75c31f609e65116ac3b621eeb7d8ccbf69135
2023-08-18 01:41:17 +08:00
hoshi-hiyouga
f11c1ae562 Merge pull request #434 from niuba/main
add last_checkpoint support

Former-commit-id: b78d461f2826c194c332ead37825704c2cb8b910
2023-08-18 01:38:31 +08:00
hoshi-hiyouga
3126164aa6 Merge branch 'main' into main
Former-commit-id: 870d2c7bf74d0da5a927bef4b8b01d15cc66a3e9
2023-08-18 01:37:23 +08:00
hiyouga
ed10486cad support bf16 ppo #551
Former-commit-id: 092088967de7409a2d51847cfc7afc83a8887320
2023-08-18 00:40:32 +08:00
hiyouga
04fa430c6c fix ChatGLM2 ppo #527 #528
Former-commit-id: 60d6ad64d7c9f6445b0df8de0153c3a311974198
2023-08-18 00:34:59 +08:00
hiyouga
fa1893b59c fix generation bug #532
Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
2023-08-17 22:21:34 +08:00
hiyouga
e993e717a5 fix streaming in pt stage #548 #549
Former-commit-id: 050e992bee2a9293cc7399b578de807b5bf9bddc
2023-08-17 17:59:26 +08:00
hiyouga
c80e56423a update readme
Former-commit-id: b74af3c9cf29e1690ae4d5acb27599b1abd152e2
2023-08-17 11:00:22 +08:00
hiyouga
ffa09a01d6 fix baichuan and intern template
Former-commit-id: e1fd18fa6ef1009f978aca5210a259251a0b19a6
2023-08-17 01:27:20 +08:00
hiyouga
7d04f8567b fix generation
Former-commit-id: 66a0300d312ef91c24fcf80667fa3b0bb8e1a342
2023-08-16 22:39:54 +08:00
hiyouga
baa709674f fix system prompt
Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
2023-08-16 01:35:52 +08:00
hiyouga
ca9a494d0c fix baichuan template #481
Former-commit-id: 7608c6c25877d97ef26a1c209c4073c9c42f4535
2023-08-15 11:38:21 +08:00
hoshi-hiyouga
37eb8c05cc Merge pull request #516 from liuyanyi/add_gitignore
[Enhance] Add .gitignore file

Former-commit-id: 12cfe5482f5ef95d8c386d0af0de381e72eab0f9
2023-08-15 11:25:40 +08:00
hiyouga
7c046edb7b fix ChatGLM RLHF
Former-commit-id: 4e43e887e432ceb7e9287b4e309b63af3c3ba1bf
2023-08-15 11:19:20 +08:00
Yanyi Liu
22cea38b20 Add .gitignore
Former-commit-id: a2ebdeef81706596617da4409fc5da71739bccdc
2023-08-15 11:13:45 +08:00
hiyouga
ef2ca0a827 alert pad_token source
Former-commit-id: f26a84e0d927d2554890daf431a93652e18f4235
2023-08-15 00:07:56 +08:00
hiyouga
7f0b908de2 update webui
Former-commit-id: da30d0fb4abdb825f3383ddd106bb06a84695b7a
2023-08-14 22:45:26 +08:00
hoshi-hiyouga
5fc5e776ff Merge pull request #511 from hiyouga/feature-autoTemplate
add template match and stage in webui

Former-commit-id: 413752ecba845cddaff5fb48db7d3d24b960eec1
2023-08-14 22:44:04 +08:00
codemayq
93b281c016 auto match template when change model_name
Former-commit-id: ab2d7ab0572765ce33a52ac71641062d5d904db4
2023-08-14 20:56:05 +08:00
codemayq
9585699918 add template match and stage in webui
Former-commit-id: d6283e7f041f08f76d18350cb5f6a6c58ca80e92
2023-08-14 20:42:59 +08:00
hiyouga
bceaba551d fix ChatGLM lm_head #494
Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
2023-08-14 14:14:48 +08:00
hiyouga
0bfeed3a7e fix bug in webui
Former-commit-id: c95f0f687689934379b6c24abf872ffcde06073b
2023-08-14 11:38:42 +08:00
hiyouga
70a780c3c0 fix webui cache
Former-commit-id: 9aba5c197fbc8abaab77f454374f8b497f0310d0
2023-08-14 11:37:01 +08:00
hiyouga
d74ab5306c update readme_zh
Former-commit-id: bdfe7e0285fdeb3a2728669dbdabf70c9652735c
2023-08-14 11:13:25 +08:00
hiyouga
688e8601ab web UI integrating RLHF
Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
2023-08-14 10:48:47 +08:00
hiyouga
4933ab5956 fix #480
Former-commit-id: ec15ca8fffacba2c34e1849c5ce90ca9989d66a2
2023-08-14 00:23:56 +08:00
hiyouga
6c7225a5d4 fix webui
Former-commit-id: 2c8b7414be9b43e20cc1d0575cc4dc1c7545fd86
2023-08-12 23:52:07 +08:00
hiyouga
a22982f2fa tiny fix
Former-commit-id: 50a34c043de6d9e1410291e1d8c1ea9d53754e9e
2023-08-12 22:02:43 +08:00
hiyouga
c95479dddb fix rope scaling
Former-commit-id: 2e0dd36700ec5e8294581c1db4b9431f755fc5f8
2023-08-12 22:00:01 +08:00
hiyouga
fc48bd8da0 update readme
Former-commit-id: 94ac570cb62aa9cd5dba105f0bb4c4da43eca042
2023-08-12 21:29:06 +08:00
hiyouga
d5323bfa3f update readme
Former-commit-id: ecfe87f34b383901f8e97ffb90af459cd55419b1
2023-08-12 21:25:19 +08:00
hiyouga
e9d4a2b507 update readme
Former-commit-id: eadbe9b7a0b6c8897e7a763b519cc5b7e00f3b2c
2023-08-12 21:23:05 +08:00
hiyouga
37bcbe8046 update readme
Former-commit-id: 6fa381400c21fa249cebcdff8c3afd72f8de20b3
2023-08-12 21:00:11 +08:00
hiyouga
fdfb644f0a support rope scaling, fix #475 #476 #478
Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
2023-08-12 20:46:27 +08:00
hoshi-hiyouga
cde9f3db57 Merge pull request #479 from hiyouga/feature-addCmdExport
add sft script preview in webui

Former-commit-id: 060225e57d13d8164beb6920410c181fbb28b77a
2023-08-12 20:41:52 +08:00
codemayq
8bf5a98815 add sft script preview in webui
Former-commit-id: 2b72649b404750226aa418b61ef5a6c9ac03938f
2023-08-12 13:53:55 +08:00
hiyouga
be566a15a5 fix unusual output of 8bit models #278 #391
Former-commit-id: 337ce5272b81f5561162beb08814b0e5abf23703
2023-08-12 00:25:29 +08:00
hiyouga
d5f1b99ac4 Release v0.1.6
Former-commit-id: 43c8b3c3c8bfb2e32d17fb3e8b194938e37d54bd
2023-08-11 23:25:57 +08:00
hiyouga
2144bb0e27 Update README_zh.md
Former-commit-id: 4fc154bcf039ba3f9240213158df757881cf3579
2023-08-11 14:06:02 +08:00
hiyouga
bc665bacc7 add defaults
Former-commit-id: 4636d3bbe6b984ca93e3a80ae5239f3ddda461bd
2023-08-11 13:56:26 +08:00
hiyouga
52bfcf4883 fix stop word in baichuan template
Former-commit-id: cba5ac9cfc81f11b97831998ea15def5e0b487c2
2023-08-11 13:51:46 +08:00
hiyouga
06df3d6fb6 fix baichuan template
Former-commit-id: b1681fe35346381cda613297f1cbb710f0a6daa6
2023-08-11 13:45:47 +08:00
hiyouga
ca719a8697 support DPO training (2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
2023-08-11 03:02:53 +08:00
hoshi-hiyouga
72dfd74005 Merge pull request #451 from jovialchen/main
huggingface login for projects must login while running

Former-commit-id: 246ac241277908909b81cdf85fec1f24449dbae9
2023-08-10 17:25:38 +08:00
hiyouga
69302c4420 fix webui val size
Former-commit-id: 490c067d4e0828832e0ebdb704a9207dc974b15b
2023-08-10 15:20:44 +08:00
jiongxuc
42d7019b2e huggingface login for projects must login while running
Former-commit-id: 0a4a2a1d3e0ff1f57215512d294d782080bd383c
2023-08-10 14:57:12 +08:00
hiyouga
5f0d0d6b9b fix template
Former-commit-id: e3967eb1cdd8d19e8afee9ba52e7eb7d6cd86129
2023-08-09 23:14:27 +08:00
hiyouga
76cb63e4f6 fix template
Former-commit-id: 907e8cd86fbd4cdfa26dad21ceaf6e01d8fe37e4
2023-08-09 23:10:20 +08:00
hiyouga
467d571206 support val set in streaming mode
Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
2023-08-09 23:00:26 +08:00
hiyouga
972bfa700a fix tokenizer
Former-commit-id: 7849587cd4e149291d08edef9a528a1bad796c7e
2023-08-09 17:52:15 +08:00
niuba
458955d0fb add last_checkpoint support
Former-commit-id: 9f1977e4de00b14a9d1b555c25bcaf12998d5046
2023-08-09 16:39:27 +08:00
hiyouga
990eeccf45 fix sft trainer
Former-commit-id: 08cc888b1569572d0cd20bcf3f07e20072a0311a
2023-08-09 16:35:03 +08:00
hiyouga
a3a7465f00 fix rm #420, fix template #426, fix #423
Former-commit-id: 70ea3caaa7a7695c77179cd1bb18707a80a373d7
2023-08-09 16:23:31 +08:00
hoshi-hiyouga
031a819257 fix llama2 template
Former-commit-id: 6c74f726d4e672f5a1a57df201c27c1f697384f0
2023-08-09 00:58:27 +08:00
hoshi-hiyouga
eb4b4e3c8c fix tokenizer
Former-commit-id: fa463ef279b596d5d53cc169831f51b42031fc05
2023-08-09 00:54:54 +08:00
hiyouga
d2e1fe9b1d update webui
Former-commit-id: 343a4cd82b07a40f96ba413d1d991419ff07a24a
2023-08-09 00:26:11 +08:00
hiyouga
6e27a9e39a fix tokenizer #417
Former-commit-id: 01aa678311bfd213a4b410a4e0ff09f48a0d40a1
2023-08-08 23:59:41 +08:00
hiyouga
805478c911 fix bug
Former-commit-id: 0dff1d951f1a9fe05a74d334bf477b55c7c64199
2023-08-08 21:28:28 +08:00
hiyouga
a281cdeb89 fix bug
Former-commit-id: c13ce66021b21e015871b84489eeafa127a424a4
2023-08-08 17:55:55 +08:00
hiyouga
cda698a67f fix chatml template #408
Former-commit-id: 21e0cc3f44c35ae689b00b274391492f413725ac
2023-08-08 17:44:39 +08:00
hiyouga
15acd17716 update args spec
Former-commit-id: a006068346edda6e2851b23d2005fdb218a7287d
2023-08-07 15:23:35 +08:00
hiyouga
34a2bddfcd update readme
Former-commit-id: 06bcbb901f69265632892a5fcbc956b8be1153da
2023-08-07 15:02:02 +08:00
hiyouga
370f817549 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Efficient-Tuning
Former-commit-id: 5c5657227db285048e3850631badb040eea9b6ca
2023-08-07 13:59:16 +08:00
hiyouga
041390c37e fix #376
Former-commit-id: a5b01257ba3323bcb2dd0217fb89a387e39ddbec
2023-08-07 13:58:59 +08:00
hoshi-hiyouga
d9fe4bf500 Merge pull request #382 from hiyouga/feature-updateReadme
add detailed model configs

Former-commit-id: 371c50cf3fd4e3f5e8fb390508c27cb5f18fa531
2023-08-07 13:43:38 +08:00
hiyouga
e0c7e944fc update trainer
Former-commit-id: 0d39b53a5164e34d22fe0a492eaa0d7ac63102fe
2023-08-07 13:34:35 +08:00
codemayq
0845fe67db add detailed model configs
Former-commit-id: 438c43f820e39738eaa1c296aadcf6d141c3289f
2023-08-07 09:30:23 +08:00
hiyouga
fe3b12d900 fix qwen eos token
Former-commit-id: 770830c67886f5872b39b9608949ec62d4616b27
2023-08-06 13:31:17 +08:00
hiyouga
a70d56864e fix qwen tokenizer #361
Former-commit-id: 78a2fa95c8ab669254a6c8fce8138c4395fb0a09
2023-08-05 17:06:05 +08:00
hiyouga
fdbb2c5378 fix template for tiktoken
Former-commit-id: 8328447f81eb5b90310df08cf2928c83ef6355fe
2023-08-05 13:42:42 +08:00
hiyouga
3c0aaf42af remove redundant code
Former-commit-id: dcec1717592107ba9d26eb2ac520309da19d1805
2023-08-05 00:27:27 +08:00
hiyouga
438e19160a fix template
Former-commit-id: b88200a88ea112e043dc44058606805c60e32844
2023-08-05 00:25:00 +08:00
hiyouga
f2b2ff6950 fix llama2 template
Former-commit-id: 08f37145e0bca5f1a8fd7bad01c64dc69b07361b
2023-08-05 00:07:54 +08:00
hoshi-hiyouga
86cef96305 Support safe ChatML template, fix qwen tok #351 #354
https://github.com/openai/openai-python/blob/main/chatml.md
Former-commit-id: 94bfc9d85f7cef3a5eb15085e0124a424373814f
2023-08-05 00:00:23 +08:00
hiyouga
5f50944baf fix bos and eos token
Former-commit-id: ab386f4c0fb5eaac24264a5bbef4c03deeb92158
2023-08-04 23:55:57 +08:00
hiyouga
0804fd2353 fix encode
Former-commit-id: ec382abd906d93cf78c7fbaec753ce6bcf8cfebd
2023-08-04 23:27:55 +08:00
hiyouga
86419eb457 support chatml safe encoding
Former-commit-id: ea52bb135bf9d07738091006ec7ada8df14cf15e
2023-08-04 23:14:28 +08:00
hiyouga
76f3ae7bf3 support interleave probs
Former-commit-id: 168d99816f9bdc746c587f7f09753ba7e0a4b19d
2023-08-04 21:27:35 +08:00
hiyouga
aaa85190eb fix webui export model
Former-commit-id: c34469c05e681239db23e2e666b5ac6a4e38aba9
2023-08-04 14:20:27 +08:00
hiyouga
e2a4e926b9 fix mtloader
Former-commit-id: ca48c2c02c3cfa9afb99971b50daeda9cf14e7cb
2023-08-03 19:29:02 +08:00
hiyouga
d6e922dc1c tiny fix
Former-commit-id: 81ef7017a4c96441951adeff0276cc5ab76a3544
2023-08-03 17:42:28 +08:00
hiyouga
27f4317ec6 fix qwen inference
Former-commit-id: 823f0de0ca0a92b6f48a90e5ffe57a48dc018f1d
2023-08-03 16:31:55 +08:00
hiyouga
e434348216 fix qwen inference
Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
2023-08-03 16:15:38 +08:00
hiyouga
2e19afedb8 support Qwen-7B, fix InternLM-7B inference
Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
2023-08-03 15:53:32 +08:00
hiyouga
da08fa7c63 update web demo
Former-commit-id: 5b6ad9adb665096bfb36dc90789a1d4a16345122
2023-08-03 13:28:28 +08:00
hiyouga
9c96b97dc7 fix webui
Former-commit-id: e87630ef77977b2879f1199b9a421acbbbb32a51
2023-08-03 12:43:12 +08:00
hiyouga
28a51b622b modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
2023-08-02 23:17:36 +08:00
hiyouga
8bd1da7144 fix PPO trainer
Former-commit-id: 21982a7d4dd9b7c3a1145b481f02b9990e32dc00
2023-08-02 19:10:23 +08:00
hiyouga
e4d0b8ee6e update ppo trainer
Former-commit-id: c27136a83e167465d3f825e40f10c7b9fcfbf97a
2023-08-02 18:46:41 +08:00
hiyouga
1dfb28b362 fix memory leak of PPO trainer
Former-commit-id: 38410894a5ebf0b043b55a6bd5cca3cd0a44b27d
2023-08-02 17:41:34 +08:00
151 changed files with 12476 additions and 4862 deletions

58
.github/ISSUE_TEMPLATE/bug-report.yml vendored Normal file
View File

@@ -0,0 +1,58 @@
name: "\U0001F41B Bug / Help"
description: Create a report to help us improve the LLaMA Factory
body:
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the README carefully and searched the existing issues.
请确保您已经认真阅读了 README 并且搜索过现有的 Issue。
options:
- label: I have read the README and searched the existing issues.
required: true
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide code snippets, error messages and stack traces that reproduces the problem.
请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。
Remember to use Markdown tags to correctly format your code.
请合理使用 Markdown 标签来格式化您的文本。
placeholder: |
python src/train_bash.py ...
- type: textarea
id: expected-behavior
validations:
required: false
attributes:
label: Expected behavior
description: |
Please provide a clear and concise description of what you would expect to happen.
请提供您原本的目的,即这段代码的期望行为。
- type: textarea
id: system-info
validations:
required: false
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
placeholder: transformers version, platform, python version, ...
- type: textarea
id: others
validations:
required: false
attributes:
label: Others

29
.github/workflows/tests.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: tests
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
check_code_quality:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install black ruff
- name: Check quality
run: |
make style && make quality

165
.gitignore vendored Normal file
View File

@@ -0,0 +1,165 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# custom .gitignore
user.config
saves/
cache/

128
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
`hoshihiyouga AT gmail DOT com`.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

11
Makefile Normal file
View File

@@ -0,0 +1,11 @@
.PHONY: quality style
check_dirs := src tests
quality:
black --check $(check_dirs)
ruff $(check_dirs)
style:
black $(check_dirs)
ruff $(check_dirs) --fix

593
README.md
View File

@@ -1,102 +1,218 @@
# LLaMA Efficient Tuning ![# LLaMA Factory](assets/logo.png)
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Efficient-Tuning?style=social)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers) [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Efficient-Tuning)](LICENSE) [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/) [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls) [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20In%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
👋 Join our [WeChat](assets/wechat.jpg). 👋 Join our [WeChat](assets/wechat.jpg).
\[ English | [中文](README_zh.md) \] \[ English | [中文](README_zh.md) \]
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
## Table of Contents
- [Benchmark](#benchmark)
- [Changelog](#changelog)
- [Supported Models](#supported-models)
- [Supported Training Approaches](#supported-training-approaches)
- [Provided Datasets](#provided-datasets)
- [Requirement](#requirement)
- [Getting Started](#getting-started)
- [Projects using LLaMA Factory](#projects-using-llama-factory)
- [License](#license)
- [Citation](#citation)
- [Acknowledgement](#acknowledgement)
## Benchmark
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
![benchmark](assets/benchmark.svg)
<details><summary>Definitions</summary>
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
</details>
## Changelog ## Changelog
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset. [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details. [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model. [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. <details><summary>Full Changelog</summary>
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model. [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model. [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model. [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [Hugging Face Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details. [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**. [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model. [23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models. [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model. [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
[23/07/31] We supported **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
</details>
## Supported Models ## Supported Models
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B) | Model | Model size | Default module | Template |
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B) | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B) | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B) | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B) | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
- [InternLM](https://github.com/InternLM/InternLM) (7B) | [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
> [!NOTE]
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
## Supported Training Approaches ## Supported Training Approaches
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
- Full-parameter tuning | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
- Partial-parameter tuning | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [LoRA](https://arxiv.org/abs/2106.09685) | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [QLoRA](https://arxiv.org/abs/2305.14314) | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652) | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- Full-parameter tuning | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685) > [!NOTE]
- [QLoRA](https://arxiv.org/abs/2305.14314) > Use `--quantization_bit 4` argument to enable QLoRA.
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
## Provided Datasets ## Provided Datasets
- For pre-training: <details><summary>Pre-training datasets</summary>
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - [Wiki Demo (en)](data/wiki_demo.txt)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) - [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- For supervised fine-tuning: - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) </details>
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) <details><summary>Supervised fine-tuning datasets</summary>
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) - [Self Cognition (zh)](data/self_cognition.json)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- For reward modelling: - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
</details>
<details><summary>Preference datasets</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details>
Please refer to [data/README.md](data/README.md) for details. Please refer to [data/README.md](data/README.md) for details.
@@ -111,28 +227,37 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+ - Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- jieba, rouge-chinese and nltk (used at evaluation) - sentencepiece, protobuf and tiktoken
- gradio and matplotlib (used in web_demo.py) - jieba, rouge-chinese and nltk (used at evaluation and predict)
- uvicorn, fastapi and sse-starlette (used in api_demo.py) - gradio and matplotlib (used in web UI)
- uvicorn, fastapi and sse-starlette (used in API)
And **powerful GPUs**! ### Hardware Requirement
| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
## Getting Started ## Getting Started
### Data Preparation (optional) ### Data Preparation (optional)
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset. Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
Note: please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`. > [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
### Dependence Installation (optional) ### Dependence Installation (optional)
```bash ```bash
git lfs install git clone https://github.com/hiyouga/LLaMA-Factory.git
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git conda create -n llama_factory python=3.10
conda create -n llama_etuning python=3.10 conda activate llama_factory
conda activate llama_etuning cd LLaMA-Factory
cd LLaMA-Efficient-Tuning
pip install -r requirements.txt pip install -r requirements.txt
``` ```
@@ -142,24 +267,43 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
``` ```
### All-in-one Web UI ### Use ModelScope Hub (optional)
If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
``` ```
Currently the web UI only supports training on **a single GPU**. Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
### (Continually) Pre-Training ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path modelscope/Llama-2-7b-ms \
... # arguments (same as above)
```
LLaMA Board also supports using the models and datasets on the ModelScope Hub.
```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
### Train on a single GPU
> [!IMPORTANT]
> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
#### Pre-Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \ --stage pt \
--model_name_or_path path_to_your_model \
--do_train \ --do_train \
--model_name_or_path path_to_llama_model \
--dataset wiki_demo \ --dataset wiki_demo \
--template default \
--finetuning_type lora \ --finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \ --output_dir path_to_pt_checkpoint \
--overwrite_cache \ --overwrite_cache \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 4 \
@@ -173,16 +317,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
### Supervised Fine-Tuning #### Supervised Fine-Tuning
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \ --stage sft \
--model_name_or_path path_to_your_model \
--do_train \ --do_train \
--model_name_or_path path_to_llama_model \
--dataset alpaca_gpt4_en \ --dataset alpaca_gpt4_en \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \ --output_dir path_to_sft_checkpoint \
--overwrite_cache \ --overwrite_cache \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 4 \
@@ -196,22 +341,77 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
Remember to specify `--lora_target W_pack` if you are using Baichuan models. #### Reward Modeling
### Reward Model Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \ --stage rm \
--model_name_or_path path_to_your_model \
--do_train \ --do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_en \ --dataset comparison_gpt4_en \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--resume_lora_training False \ --lora_target q_proj,v_proj \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \ --output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-6 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--top_k 0 \
--top_p 0.9 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
#### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \ --gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \ --lr_scheduler_type cosine \
--logging_steps 10 \ --logging_steps 10 \
@@ -222,50 +422,22 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
### PPO Training (RLHF)
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
```
### Distributed Training ### Distributed Training
#### Use Huggingface Accelerate
```bash ```bash
accelerate config # configure the environment accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above) accelerate launch src/train_bash.py # arguments (same as above)
``` ```
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary> <details><summary>Example config for LoRA training</summary>
```yaml ```yaml
compute_environment: LOCAL_MACHINE compute_environment: LOCAL_MACHINE
deepspeed_config: distributed_type: MULTI_GPU
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no' downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 machine_rank: 0
main_training_function: main main_training_function: main
mixed_precision: fp16 mixed_precision: fp16
@@ -281,121 +453,170 @@ use_cpu: false
</details> </details>
### Evaluation (BLEU and ROUGE_CHINESE) #### Use DeepSpeed
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--stage sft \ --deepspeed ds_config.json \
--model_name_or_path path_to_your_model \ ... # arguments (same as above)
--do_eval \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
``` ```
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation. <details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
### Predict ```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### Merge LoRA weights and export model
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ python src/export_model.py \
--stage sft \ --model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \ --adapter_name_or_path path_to_checkpoint \
--do_predict \
--dataset alpaca_gpt4_en \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint \ --export_dir path_to_export \
--output_dir path_to_predict_result \ --export_size 2 \
--per_device_eval_batch_size 8 \ --export_legacy_format False
--max_samples 100 \
--predict_with_generate
``` ```
> [!WARNING]
> Merging LoRA weights into a quantized model is not supported.
> [!TIP]
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.
### API Demo ### API Demo
```bash ```bash
python src/api_demo.py \ python src/api_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora
--checkpoint_dir path_to_checkpoint
``` ```
Visit `http://localhost:8000/docs` for API documentation. > [!TIP]
> Visit `http://localhost:8000/docs` for API documentation.
### CLI Demo ### CLI Demo
```bash ```bash
python src/cli_demo.py \ python src/cli_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora
--checkpoint_dir path_to_checkpoint
``` ```
### Web Demo ### Web Demo
```bash ```bash
python src/web_demo.py \ python src/web_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora
--checkpoint_dir path_to_checkpoint
``` ```
### Export model ### Evaluation
```bash ```bash
python src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--template default \ --adapter_name_or_path path_to_checkpoint \
--template vanilla \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint \ --task mmlu \
--output_dir path_to_export --split test \
--lang en \
--n_shot 5 \
--batch_size 4
``` ```
## TODO ### Predict
- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)). ```bash
- [ ] Implementing multi-query attention for faster inference. CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
- [ ] Supporting full-parameter RLHF training. --stage sft \
--do_predict \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate \
--fp16
```
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 predict.
> [!TIP]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
## Projects using LLaMA Factory
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
> [!TIP]
> If you have a project that should be incorporated, please contact via email or create a pull request.
## License ## License
This repository is licensed under the [Apache-2.0 License](LICENSE). This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
## Citation ## Citation
If this work is helpful, please kindly cite as: If this work is helpful, please kindly cite as:
```bibtex ```bibtex
@Misc{llama-efficient-tuning, @Misc{llama-factory,
title = {LLaMA Efficient Tuning}, title = {LLaMA Factory},
author = {hiyouga}, author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}}, howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
year = {2023} year = {2023}
} }
``` ```
## Acknowledgement ## Acknowledgement
This repo is a sibling of [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning). They share a similar code structure of efficient tuning on large language models. This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
## Star History ## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Efficient-Tuning&type=Date) ![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)

View File

@@ -1,104 +1,220 @@
# LLaMA Efficient Tuning ![# LLaMA Factory](assets/logo.png)
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Efficient-Tuning?style=social)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers) [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Efficient-Tuning)](LICENSE) [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/) [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls) [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20In%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
👋 加入我们的[微信群](assets/wechat.jpg)。 👋 加入我们的[微信群](assets/wechat.jpg)。
\[ [English](README.md) | 中文 \] \[ [English](README.md) | 中文 \]
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。该模式目前仅支持单卡训练
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
## 目录
- [性能指标](#性能指标)
- [更新日志](#更新日志)
- [模型](#模型)
- [训练方法](#训练方法)
- [数据集](#数据集)
- [软硬件依赖](#软硬件依赖)
- [如何使用](#如何使用)
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
- [协议](#协议)
- [引用](#引用)
- [致谢](#致谢)
## 性能指标
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
![benchmark](assets/benchmark.svg)
<details><summary>变量定义</summary>
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4截断长度=1024
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4截断长度=1024
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1截断长度=1024
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`
</details>
## 更新日志 ## 更新日志
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming``--max_steps 100` 参数来流式加载数据集 [24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft) [24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力
[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。 <details><summary>展开日志</summary>
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model``--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数 [23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)
[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数 [23/12/01] 我们支持了**[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b``--lora_target query_key_value` 参数 [23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune例如 `--neftune_noise_alpha 5`
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft) [23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中 [23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)
[23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B``--lora_target W_pack` 参数 [23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn` 参数以启用 FlashAttention-2
[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调 [23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型
[23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt``--lora_target query_key_value` 参数 [23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)
[23/07/31] 我们支持了**数据流式加载**。请使用 `--streaming``--max_steps 10000` 参数来流式加载数据集。
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft)。
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
</details>
## 模型 ## 模型
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B) | 模型名 | 模型大小 | 默认模块 | Template |
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B) | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B) | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B) | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B) | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
- [InternLM](https://github.com/InternLM/InternLM) (7B) | [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
## 微调方法 > [!NOTE]
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) 项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)
- 全参数微调
- 部分参数微调 ## 训练方法
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314) | 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
- [指令监督微调](https://arxiv.org/abs/2109.01652) | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
- 全参数微调 | 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- 部分参数微调 | 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [LoRA](https://arxiv.org/abs/2106.09685) | 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [QLoRA](https://arxiv.org/abs/2305.14314) | PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [人类反馈的强化学习RLHF](https://arxiv.org/abs/2203.02155) | DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314) > [!NOTE]
> 请使用 `--quantization_bit 4` 参数来启用 QLoRA 训练。
## 数据集 ## 数据集
- 用于二次预训练: <details><summary>预训练数据集</summary>
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- 用于指令监督微调:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- 用于奖励模型训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
使用方法请参考 [data/README.md](data/README_zh.md) 文件。 - [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
</details>
<details><summary>指令微调数据集</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
</details>
<details><summary>偏好数据集</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details>
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。 部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
@@ -107,32 +223,41 @@ pip install --upgrade huggingface_hub
huggingface-cli login huggingface-cli login
``` ```
## 软件依赖 ## 软件依赖
- Python 3.8+ 和 PyTorch 1.13.1+ - Python 3.8+ 和 PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL - 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
- jieba, rouge-chinese 和 nltk (用于评估) - sentencepiece, protobuf 和 tiktoken
- jieba, rouge-chinese 和 nltk (用于评估及预测)
- gradio 和 matplotlib (用于网页端交互) - gradio 和 matplotlib (用于网页端交互)
- uvicorn, fastapi 和 sse-starlette (用于 API) - uvicorn, fastapi 和 sse-starlette (用于 API)
以及 **强而有力的 GPU** ### 硬件依赖
| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
## 如何使用 ## 如何使用
### 数据准备(可跳过) ### 数据准备(可跳过)
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。 关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md` > [!NOTE]
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
### 环境搭建(可跳过) ### 环境搭建(可跳过)
```bash ```bash
git lfs install git clone https://github.com/hiyouga/LLaMA-Factory.git
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git conda create -n llama_factory python=3.10
conda create -n llama_etuning python=3.10 conda activate llama_factory
conda activate llama_etuning cd LLaMA-Factory
cd LLaMA-Efficient-Tuning
pip install -r requirements.txt pip install -r requirements.txt
``` ```
@@ -142,24 +267,43 @@ pip install -r requirements.txt
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
``` ```
### 浏览器一键微调/测试 ### 使用魔搭社区(可跳过)
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
``` ```
目前网页 UI 仅支持**单卡训练**。 接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型)
### 二次预训练 ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path modelscope/Llama-2-7b-ms \
... # 参数同上
```
LLaMA Board 同样支持魔搭社区的模型和数据集下载。
```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
### 单 GPU 训练
> [!IMPORTANT]
> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
#### 预训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \ --stage pt \
--model_name_or_path path_to_your_model \
--do_train \ --do_train \
--model_name_or_path path_to_llama_model \
--dataset wiki_demo \ --dataset wiki_demo \
--template default \
--finetuning_type lora \ --finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \ --output_dir path_to_pt_checkpoint \
--overwrite_cache \ --overwrite_cache \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 4 \
@@ -173,16 +317,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
### 指令监督微调 #### 指令监督微调
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \ --stage sft \
--model_name_or_path path_to_your_model \
--do_train \ --do_train \
--model_name_or_path path_to_llama_model \
--dataset alpaca_gpt4_zh \ --dataset alpaca_gpt4_zh \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \ --output_dir path_to_sft_checkpoint \
--overwrite_cache \ --overwrite_cache \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 4 \
@@ -196,22 +341,77 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
使用 Baichuan 模型时请指定 `--lora_target W_pack` 参数。 #### 奖励模型训练
### 奖励模型训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \ --stage rm \
--model_name_or_path path_to_your_model \
--do_train \ --do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_zh \ --dataset comparison_gpt4_zh \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--resume_lora_training False \ --lora_target q_proj,v_proj \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \ --output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-6 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--top_k 0 \
--top_p 0.9 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!WARNING]
> 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。
#### DPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \ --gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \ --lr_scheduler_type cosine \
--logging_steps 10 \ --logging_steps 10 \
@@ -222,50 +422,22 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
### RLHF 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
```
### 多 GPU 分布式训练 ### 多 GPU 分布式训练
#### 使用 Huggingface Accelerate
```bash ```bash
accelerate config # 首先配置分布式环境 accelerate config # 首先配置分布式环境
accelerate launch src/train_bash.py # 参数同上 accelerate launch src/train_bash.py # 参数同上
``` ```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 Accelerate 配置示例</summary> <details><summary>LoRA 训练的 Accelerate 配置示例</summary>
```yaml ```yaml
compute_environment: LOCAL_MACHINE compute_environment: LOCAL_MACHINE
deepspeed_config: distributed_type: MULTI_GPU
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no' downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 machine_rank: 0
main_training_function: main main_training_function: main
mixed_precision: fp16 mixed_precision: fp16
@@ -281,121 +453,170 @@ use_cpu: false
</details> </details>
### 指标评估BLEU分数和汉语ROUGE分数 #### 使用 DeepSpeed
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--stage sft \ --deepspeed ds_config.json \
--model_name_or_path path_to_your_model \ ... # 参数同上
--do_eval \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
``` ```
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128` 参数。 <details><summary>使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例</summary>
### 模型预测 ```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### 合并 LoRA 权重并导出模型
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ python src/export_model.py \
--stage sft \ --model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \ --adapter_name_or_path path_to_checkpoint \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint \ --export_dir path_to_export \
--output_dir path_to_predict_result \ --export_size 2 \
--per_device_eval_batch_size 8 \ --export_legacy_format False
--max_samples 100 \
--predict_with_generate
``` ```
> [!WARNING]
> 尚不支持量化模型的 LoRA 权重合并及导出。
> [!TIP]
> 合并 LoRA 权重之后可再次使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化模型。
### API 服务 ### API 服务
```bash ```bash
python src/api_demo.py \ python src/api_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora
--checkpoint_dir path_to_checkpoint
``` ```
关于 API 文档请见 `http://localhost:8000/docs` > [!TIP]
> 关于 API 文档请见 `http://localhost:8000/docs`。
### 命令行测试 ### 命令行测试
```bash ```bash
python src/cli_demo.py \ python src/cli_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora
--checkpoint_dir path_to_checkpoint
``` ```
### 浏览器测试 ### 浏览器测试
```bash ```bash
python src/web_demo.py \ python src/web_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora
--checkpoint_dir path_to_checkpoint
``` ```
### 导出微调模型 ### 模型评估
```bash ```bash
python src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--template default \ --adapter_name_or_path path_to_checkpoint \
--template vanilla \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint \ --task ceval \
--output_dir path_to_export --split validation \
--lang zh \
--n_shot 5 \
--batch_size 4
``` ```
## TODO ### 模型预测
- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。 ```bash
- [ ] 在推理阶段使用 Multi-query attention 进行加速。 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
- [ ] 支持 RLHF 的全参数微调。 --stage sft \
--do_predict \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate \
--fp16
```
> [!WARNING]
> 如果使用 fp16 精度进行 LLaMA-2 模型的预测,请使用 `--per_device_eval_batch_size=1`。
> [!TIP]
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
## 使用了 LLaMA Factory 的项目
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**MBTI性格大模型项目根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
> [!TIP]
> 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
## 协议 ## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议: 使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
## 引用 ## 引用
如果您觉得此项目有帮助,请考虑以下列格式引用 如果您觉得此项目有帮助,请考虑以下列格式引用
```bibtex ```bibtex
@Misc{llama-efficient-tuning, @Misc{llama-factory,
title = {LLaMA Efficient Tuning}, title = {LLaMA Factory},
author = {hiyouga}, author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}}, howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
year = {2023} year = {2023}
} }
``` ```
## 致谢 ## 致谢
本项目是 [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) 的同类项目。采用了类似的代码结构和训练方法 本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出
## Star History ## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Efficient-Tuning&type=Date) ![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)

1216
assets/benchmark.svg Normal file

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 29 KiB

View File

@@ -2,17 +2,127 @@ If you are using a custom dataset, please provide your dataset definition in the
```json ```json
"dataset_name": { "dataset_name": {
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)", "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)", "ms_hub_url": "the name of the dataset repository on the ModelScope hub. (if specified, ignore script_url and file_name)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)", "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)", "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
"columns": { "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)", "subset": "the name of the subset. (optional, default: None)",
"query": "the name of the column in the datasets containing the queries. (default: input)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
"response": "the name of the column in the datasets containing the responses. (default: output)", "ranking": "whether the dataset is a preference dataset or not. (default: false)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)" "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns (optional)": {
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
"query": "the column name in the dataset containing the queries. (default: input)",
"response": "the column name in the dataset containing the responses. (default: output)",
"history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
"content_tag": "the key in the message represents the content. (default: value)",
"user_tag": "the value of the role_tag represents the user. (default: human)",
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
} }
} }
``` ```
where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair. Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
```json
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
```json
"dataset_name": {
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"system": "system",
"history": "history"
}
}
```
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
For the pre-training datasets, only the `prompt` column will be used for training.
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
```json
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
}
```
The dataset in sharegpt format should follow the below format:
```json
[
{
"conversations": [
{
"from": "human",
"value": "user instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"system": "system prompt (optional)",
"tools": "tool description (optional)"
}
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
```json
"dataset_name": {
"columns": {
"messages": "conversations",
"system": "system",
"tools": "tools"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
}
}
```
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.

View File

@@ -1,18 +1,128 @@
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以下格式提供您的数据集定义。 如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
```json ```json
"数据集名称": { "数据集名称": {
"hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数", "hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数", "ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)", "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选", "file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练",
"columns": { "subset": "数据集子集的名称可选默认None",
"folder": "Hugging Face 仓库的文件夹名称可选默认None",
"ranking": "是否为偏好数据集可选默认False",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns可选": {
"prompt": "数据集代表提示词的表头名称默认instruction", "prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input", "query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output", "response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None" "history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
"content_tag": "消息中代表文本内容的键名默认value",
"user_tag": "消息中代表用户的 role_tag默认human",
"assistant_tag": "消息中代表助手的 role_tag默认gpt",
"observation_tag": "消息中代表工具返回结果的 role_tag默认observation",
"function_tag": "消息中代表工具调用的 role_tag默认function_call",
"system_tag": "消息中代表系统提示的 role_tag默认system会覆盖 system 列)"
} }
} }
``` ```
其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复 添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"]
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
```json
"数据集名称": {
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"system": "system",
"history": "history"
}
}
```
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery``response` 列对应的内容为模型回答。
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**也会被用于训练**。
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
```json
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"优质回答",
"劣质回答"
]
}
```
而 sharegpt 格式的数据集按照以下方式组织:
```json
[
{
"conversations": [
{
"from": "human",
"value": "用户指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"system": "系统提示词(选填)",
"tools": "工具描述(选填)"
}
]
```
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
```json
"数据集名称": {
"columns": {
"messages": "conversations",
"system": "system",
"tools": "tools"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
}
}
```
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
预训练数据集和偏好数据集尚不支持 sharegpt 格式。

View File

@@ -1 +1 @@
fc9a6a3458caca2af8dafc6181773fe10c6d8657 34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b

View File

@@ -1,6 +1,5 @@
import json import json
import datasets import datasets
from typing import Any, Dict, List
_DESCRIPTION = "BELLE multiturn chat dataset." _DESCRIPTION = "BELLE multiturn chat dataset."
@@ -23,11 +22,9 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0") VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo: def _info(self):
features = datasets.Features({ features = datasets.Features({
"instruction": datasets.Value("string"), "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
}) })
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
@@ -37,7 +34,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
citation=_CITATION citation=_CITATION
) )
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]: def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download(_URL) file_path = dl_manager.download(_URL)
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
@@ -48,10 +45,11 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
) )
] ]
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat with history def _generate_examples(self, filepath: str):
with open(filepath, "r", encoding="utf-8") as f: with open(filepath, "r", encoding="utf-8") as f:
for key, row in enumerate(f): for key, row in enumerate(f):
data = json.loads(row) data = json.loads(row)
conversations = []
prompt = data["instruction"].strip() prompt = data["instruction"].strip()
response = data["output"].strip() response = data["output"].strip()
@@ -59,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
human_idx = prompt.rfind("Human:") human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip() query = prompt[human_idx+6:assist_idx].strip()
prompt = prompt[:human_idx].strip() prompt = prompt[:human_idx].strip()
history = [] conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
while prompt.rfind("Assistant:") != -1: while prompt.rfind("Assistant:") != -1:
assist_idx = prompt.rfind("Assistant:") assist_idx = prompt.rfind("Assistant:")
@@ -67,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
if human_idx != -1: if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip() old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip() old_resp = prompt[assist_idx+10:].strip()
history.insert(0, (old_query, old_resp)) conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else: else:
break break
prompt = prompt[:human_idx].strip() prompt = prompt[:human_idx].strip()
yield key, { yield key, {"conversations": conversations}
"instruction": query,
"output": response,
"history": history
}

View File

@@ -3,7 +3,7 @@ import datasets
from typing import Any, Dict, List from typing import Any, Dict, List
_DESCRIPTION = "An example of dataset for LLaMA." _DESCRIPTION = "An example of dataset."
_CITATION = "" _CITATION = ""
_HOMEPAGE = "" _HOMEPAGE = ""
_LICENSE = "" _LICENSE = ""

View File

@@ -0,0 +1 @@
4748dff00d1dc42768a5b6cc772143c313017812

View File

@@ -1,9 +1,9 @@
import json import json
import datasets import datasets
from typing import Any, Dict, List from typing import List
_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM." _DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = "" _CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf" _HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit" _LICENSE = "mit"
@@ -42,7 +42,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
citation=_CITATION citation=_CITATION
) )
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]: def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download_and_extract(_URLS) file_path = dl_manager.download_and_extract(_URLS)
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
@@ -59,7 +59,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
) )
] ]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM def _generate_examples(self, filepaths: List[str]):
key = 0 key = 0
for filepath in filepaths: for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f: with open(filepath, "r", encoding="utf-8") as f:

View File

@@ -1 +0,0 @@
f967a4f6d04a11308a15524aa9a846a19a8d1e83

View File

@@ -1 +0,0 @@
0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8

View File

@@ -1 +0,0 @@
38c89869c6aeca2a3af9ea1e09afe460f9b46810

View File

@@ -1,6 +1,6 @@
import json import json
import datasets import datasets
from typing import Any, Dict, List from typing import List
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data." _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
@@ -21,15 +21,13 @@ _LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl" _BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
class BelleMultiturn(datasets.GeneratorBasedBuilder): class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0") VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo: def _info(self):
features = datasets.Features({ features = datasets.Features({
"instruction": datasets.Value("string"), "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
}) })
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
@@ -39,8 +37,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
citation=_CITATION citation=_CITATION
) )
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]: def _split_generators(self, dl_manager: datasets.DownloadManager):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
@@ -50,7 +48,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
) )
] ]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths: for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f: with open(filepath, "r", encoding="utf-8") as f:
for row in f: for row in f:
@@ -58,19 +56,14 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
data = json.loads(row) data = json.loads(row)
except: except:
continue continue
key = data["id"] key: int = data["id"]
content = data["data"] content: List[str] = data["data"]
if len(content) % 2 == 1: if len(content) % 2 == 1:
content.pop(-1) content.pop(-1)
if len(content) < 2: if len(content) < 2:
continue continue
conversations = [{
query = content[-2] "from": "human" if i % 2 == 0 else "gpt",
response = content[-1] "value": content[i]
history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)] } for i in range(len(content))]
yield key, {"conversations": conversations}
yield key, {
"instruction": query,
"output": response,
"history": history
}

View File

@@ -1,50 +0,0 @@
Machine learning (ML) is a field devoted to understanding and building methods that let machines "learn" that is, methods that leverage data to improve computer performance on some set of tasks.
Machine learning algorithms build a model based on sample data, known as training data, in order to make predictions or decisions without being explicitly programmed to do so. Machine learning algorithms are used in a wide variety of applications, such as in medicine, email filtering, speech recognition, agriculture, and computer vision, where it is difficult or unfeasible to develop conventional algorithms to perform the needed tasks.
A subset of machine learning is closely related to computational statistics, which focuses on making predictions using computers, but not all machine learning is statistical learning. The study of mathematical optimization delivers methods, theory and application domains to the field of machine learning. Data mining is a related field of study, focusing on exploratory data analysis through unsupervised learning.
Some implementations of machine learning use data and neural networks in a way that mimics the working of a biological brain.
In its application across business problems, machine learning is also referred to as predictive analytics.
Learning algorithms work on the basis that strategies, algorithms, and inferences that worked well in the past are likely to continue working well in the future. These inferences can sometimes be obvious, such as "since the sun rose every morning for the last 10,000 days, it will probably rise tomorrow morning as well". Other times, they can be more nuanced, such as "X% of families have geographically separate species with color variants, so there is a Y% chance that undiscovered black swans exist".
Machine learning programs can perform tasks without being explicitly programmed to do so. It involves computers learning from data provided so that they carry out certain tasks. For simple tasks assigned to computers, it is possible to program algorithms telling the machine how to execute all steps required to solve the problem at hand; on the computer's part, no learning is needed. For more advanced tasks, it can be challenging for a human to manually create the needed algorithms. In practice, it can turn out to be more effective to help the machine develop its own algorithm, rather than having human programmers specify every needed step.
The discipline of machine learning employs various approaches to teach computers to accomplish tasks where no fully satisfactory algorithm is available. In cases where vast numbers of potential answers exist, one approach is to label some of the correct answers as valid. This can then be used as training data for the computer to improve the algorithm(s) it uses to determine correct answers. For example, to train a system for the task of digital character recognition, the MNIST dataset of handwritten digits has often been used.
The term machine learning was coined in 1959 by Arthur Samuel, an IBM employee and pioneer in the field of computer gaming and artificial intelligence. The synonym self-teaching computers was also used in this time period.
By the early 1960s an experimental "learning machine" with punched tape memory, called Cybertron, had been developed by Raytheon Company to analyze sonar signals, electrocardiograms, and speech patterns using rudimentary reinforcement learning. It was repetitively "trained" by a human operator/teacher to recognize patterns and equipped with a "goof" button to cause it to re-evaluate incorrect decisions. A representative book on research into machine learning during the 1960s was Nilsson's book on Learning Machines, dealing mostly with machine learning for pattern classification. Interest related to pattern recognition continued into the 1970s, as described by Duda and Hart in 1973. In 1981 a report was given on using teaching strategies so that a neural network learns to recognize 40 characters (26 letters, 10 digits, and 4 special symbols) from a computer terminal.
Tom M. Mitchell provided a widely quoted, more formal definition of the algorithms studied in the machine learning field: "A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P if its performance at tasks in T, as measured by P, improves with experience E." This definition of the tasks in which machine learning is concerned offers a fundamentally operational definition rather than defining the field in cognitive terms. This follows Alan Turing's proposal in his paper "Computing Machinery and Intelligence", in which the question "Can machines think?" is replaced with the question "Can machines do what we (as thinking entities) can do?".
Modern-day machine learning has two objectives, one is to classify data based on models which have been developed, the other purpose is to make predictions for future outcomes based on these models. A hypothetical algorithm specific to classifying data may use computer vision of moles coupled with supervised learning in order to train it to classify the cancerous moles. A machine learning algorithm for stock trading may inform the trader of future potential predictions.
As a scientific endeavor, machine learning grew out of the quest for artificial intelligence (AI). In the early days of AI as an academic discipline, some researchers were interested in having machines learn from data. They attempted to approach the problem with various symbolic methods, as well as what were then termed "neural networks"; these were mostly perceptrons and other models that were later found to be reinventions of the generalized linear models of statistics. Probabilistic reasoning was also employed, especially in automated medical diagnosis.:488
However, an increasing emphasis on the logical, knowledge-based approach caused a rift between AI and machine learning. Probabilistic systems were plagued by theoretical and practical problems of data acquisition and representation.:488 By 1980, expert systems had come to dominate AI, and statistics was out of favor. Work on symbolic/knowledge-based learning did continue within AI, leading to inductive logic programming, but the more statistical line of research was now outside the field of AI proper, in pattern recognition and information retrieval.:708710,755 Neural networks research had been abandoned by AI and computer science around the same time. This line, too, was continued outside the AI/CS field, as "connectionism", by researchers from other disciplines including Hopfield, Rumelhart, and Hinton. Their main success came in the mid-1980s with the reinvention of backpropagation.:25
Machine learning (ML), reorganized and recognized as its own field, started to flourish in the 1990s. The field changed its goal from achieving artificial intelligence to tackling solvable problems of a practical nature. It shifted focus away from the symbolic approaches it had inherited from AI, and toward methods and models borrowed from statistics, fuzzy logic, and probability theory.
Machine learning and data mining often employ the same methods and overlap significantly, but while machine learning focuses on prediction, based on known properties learned from the training data, data mining focuses on the discovery of (previously) unknown properties in the data (this is the analysis step of knowledge discovery in databases). Data mining uses many machine learning methods, but with different goals; on the other hand, machine learning also employs data mining methods as "unsupervised learning" or as a preprocessing step to improve learner accuracy. Much of the confusion between these two research communities (which do often have separate conferences and separate journals, ECML PKDD being a major exception) comes from the basic assumptions they work with: in machine learning, performance is usually evaluated with respect to the ability to reproduce known knowledge, while in knowledge discovery and data mining (KDD) the key task is the discovery of previously unknown knowledge. Evaluated with respect to known knowledge, an uninformed (unsupervised) method will easily be outperformed by other supervised methods, while in a typical KDD task, supervised methods cannot be used due to the unavailability of training data.
Machine learning also has intimate ties to optimization: many learning problems are formulated as minimization of some loss function on a training set of examples. Loss functions express the discrepancy between the predictions of the model being trained and the actual problem instances (for example, in classification, one wants to assign a label to instances, and models are trained to correctly predict the pre-assigned labels of a set of examples).
The difference between optimization and machine learning arises from the goal of generalization: while optimization algorithms can minimize the loss on a training set, machine learning is concerned with minimizing the loss on unseen samples. Characterizing the generalization of various learning algorithms is an active topic of current research, especially for deep learning algorithms.
Machine learning and statistics are closely related fields in terms of methods, but distinct in their principal goal: statistics draws population inferences from a sample, while machine learning finds generalizable predictive patterns. According to Michael I. Jordan, the ideas of machine learning, from methodological principles to theoretical tools, have had a long pre-history in statistics. He also suggested the term data science as a placeholder to call the overall field.
Leo Breiman distinguished two statistical modeling paradigms: data model and algorithmic model, wherein "algorithmic model" means more or less the machine learning algorithms like Random Forest.
Some statisticians have adopted methods from machine learning, leading to a combined field that they call statistical learning.
Analytical and computational techniques derived from deep-rooted physics of disordered systems can be extended to large-scale problems, including machine learning, e.g., to analyze the weight space of deep neural networks. Statistical physics is thus finding applications in the area of medical diagnostics.
A core objective of a learner is to generalize from its experience. Generalization in this context is the ability of a learning machine to perform accurately on new, unseen examples/tasks after having experienced a learning data set. The training examples come from some generally unknown probability distribution (considered representative of the space of occurrences) and the learner has to build a general model about this space that enables it to produce sufficiently accurate predictions in new cases.
The computational analysis of machine learning algorithms and their performance is a branch of theoretical computer science known as computational learning theory via the Probably Approximately Correct Learning (PAC) model. Because training sets are finite and the future is uncertain, learning theory usually does not yield guarantees of the performance of algorithms. Instead, probabilistic bounds on the performance are quite common. The biasvariance decomposition is one way to quantify generalization error.
For the best performance in the context of generalization, the complexity of the hypothesis should match the complexity of the function underlying the data. If the hypothesis is less complex than the function, then the model has under fitted the data. If the complexity of the model is increased in response, then the training error decreases. But if the hypothesis is too complex, then the model is subject to overfitting and generalization will be poorer.
In addition to performance bounds, learning theorists study the time complexity and feasibility of learning. In computational learning theory, a computation is considered feasible if it can be done in polynomial time. There are two kinds of time complexity results: Positive results show that a certain class of functions can be learned in polynomial time. Negative results show that certain classes cannot be learned in polynomial time.
Machine learning approaches are traditionally divided into three broad categories, which correspond to learning paradigms, depending on the nature of the "signal" or "feedback" available to the learning system:
Supervised learning: The computer is presented with example inputs and their desired outputs, given by a "teacher", and the goal is to learn a general rule that maps inputs to outputs.
Unsupervised learning: No labels are given to the learning algorithm, leaving it on its own to find structure in its input. Unsupervised learning can be a goal in itself (discovering hidden patterns in data) or a means towards an end (feature learning).
Reinforcement learning: A computer program interacts with a dynamic environment in which it must perform a certain goal (such as driving a vehicle or playing a game against an opponent). As it navigates its problem space, the program is provided feedback that's analogous to rewards, which it tries to maximize. Although each algorithm has advantages and limitations, no single algorithm works for all problems.
Supervised learning algorithms build a mathematical model of a set of data that contains both the inputs and the desired outputs. The data is known as training data, and consists of a set of training examples. Each training example has one or more inputs and the desired output, also known as a supervisory signal. In the mathematical model, each training example is represented by an array or vector, sometimes called a feature vector, and the training data is represented by a matrix. Through iterative optimization of an objective function, supervised learning algorithms learn a function that can be used to predict the output associated with new inputs. An optimal function will allow the algorithm to correctly determine the output for inputs that were not a part of the training data. An algorithm that improves the accuracy of its outputs or predictions over time is said to have learned to perform that task.
Types of supervised-learning algorithms include active learning, classification and regression. Classification algorithms are used when the outputs are restricted to a limited set of values, and regression algorithms are used when the outputs may have any numerical value within a range. As an example, for a classification algorithm that filters emails, the input would be an incoming email, and the output would be the name of the folder in which to file the email.
Similarity learning is an area of supervised machine learning closely related to regression and classification, but the goal is to learn from examples using a similarity function that measures how similar or related two objects are. It has applications in ranking, recommendation systems, visual identity tracking, face verification, and speaker verification.
Unsupervised learning algorithms take a set of data that contains only inputs, and find structure in the data, like grouping or clustering of data points. The algorithms, therefore, learn from test data that has not been labeled, classified or categorized. Instead of responding to feedback, unsupervised learning algorithms identify commonalities in the data and react based on the presence or absence of such commonalities in each new piece of data. A central application of unsupervised learning is in the field of density estimation in statistics, such as finding the probability density function. Though unsupervised learning encompasses other domains involving summarizing and explaining data features. Unsupervised learning algorithms streamlined the process of survey and graph large indel based haplotypes of a gene of interest from pan-genome.
Cluster analysis is the assignment of a set of observations into subsets (called clusters) so that observations within the same cluster are similar according to one or more predesignated criteria, while observations drawn from different clusters are dissimilar. Different clustering techniques make different assumptions on the structure of the data, often defined by some similarity metric and evaluated, for example, by internal compactness, or the similarity between members of the same cluster, and separation, the difference between clusters. Other methods are based on estimated density and graph connectivity.
Semi-supervised learning falls between unsupervised learning (without any labeled training data) and supervised learning (with completely labeled training data). Some of the training examples are missing training labels, yet many machine-learning researchers have found that unlabeled data, when used in conjunction with a small amount of labeled data, can produce a considerable improvement in learning accuracy.
In weakly supervised learning, the training labels are noisy, limited, or imprecise; however, these labels are often cheaper to obtain, resulting in larger effective training sets.
Reinforcement learning is an area of machine learning concerned with how software agents ought to take actions in an environment so as to maximize some notion of cumulative reward. Due to its generality, the field is studied in many other disciplines, such as game theory, control theory, operations research, information theory, simulation-based optimization, multi-agent systems, swarm intelligence, statistics and genetic algorithms. In machine learning, the environment is typically represented as a Markov decision process (MDP). Many reinforcements learning algorithms use dynamic programming techniques. Reinforcement learning algorithms do not assume knowledge of an exact mathematical model of the MDP and are used when exact models are infeasible. Reinforcement learning algorithms are used in autonomous vehicles or in learning to play a game against a human opponent.
Dimensionality reduction is a process of reducing the number of random variables under consideration by obtaining a set of principal variables. In other words, it is a process of reducing the dimension of the feature set, also called the "number of features". Most of the dimensionality reduction techniques can be considered as either feature elimination or extraction. One of the popular methods of dimensionality reduction is principal component analysis (PCA). PCA involves changing higher-dimensional data (e.g., 3D) to a smaller space (e.g., 2D). This results in a smaller dimension of data (2D instead of 3D), while keeping all original variables in the model without changing the data. The manifold hypothesis proposes that high-dimensional data sets lie along low-dimensional manifolds, and many dimensionality reduction techniques make this assumption, leading to the area of manifold learning and manifold regularization.
Although machine learning has been transformative in some fields, machine-learning programs often fail to deliver expected results. Reasons for this are numerous: lack of (suitable) data, lack of access to the data, data bias, privacy problems, badly chosen tasks and algorithms, wrong tools and people, lack of resources, and evaluation problems.
In 2018, a self-driving car from Uber failed to detect a pedestrian, who was killed after a collision. Attempts to use machine learning in healthcare with the IBM Watson system failed to deliver even after years of time and billions of dollars invested.
Machine learning has been used as a strategy to update the evidence related to a systematic review and increased reviewer burden related to the growth of biomedical literature. While it has improved with training sets, it has not yet developed sufficiently to reduce the workload burden without limiting the necessary sensitivity for the findings research themselves.
Machine learning approaches in particular can suffer from different data biases. A machine learning system trained specifically on current customers may not be able to predict the needs of new customer groups that are not represented in the training data. When trained on human-made data, machine learning is likely to pick up the constitutional and unconscious biases already present in society. Language models learned from data have been shown to contain human-like biases. Machine learning systems used for criminal risk assessment have been found to be biased against black people. In 2015, Google photos would often tag black people as gorillas, and in 2018 this still was not well resolved, but Google reportedly was still using the workaround to remove all gorillas from the training data, and thus was not able to recognize real gorillas at all. Similar issues with recognizing non-white people have been found in many other systems. In 2016, Microsoft tested a chatbot that learned from Twitter, and it quickly picked up racist and sexist language. Because of such challenges, the effective use of machine learning may take longer to be adopted in other domains. Concern for fairness in machine learning, that is, reducing bias in machine learning and propelling its use for human good is increasingly expressed by artificial intelligence scientists, including Fei-Fei Li, who reminds engineers that "There's nothing artificial about AI...It's inspired by people, it's created by people, and—most importantly—it impacts people. It is a powerful tool we are only just beginning to understand, and that is a profound responsibility."
Learners can also disappoint by "learning the wrong lesson". A toy example is that an image classifier trained only on pictures of brown horses and black cats might conclude that all brown patches are likely to be horses. A real-world example is that, unlike humans, current image classifiers often do not primarily make judgments from the spatial relationship between components of the picture, and they learn relationships between pixels that humans are oblivious to, but that still correlate with images of certain types of real objects. Modifying these patterns on a legitimate image can result in "adversarial" images that the system misclassifies.
Adversarial vulnerabilities can also result in nonlinear systems, or from non-pattern perturbations. Some systems are so brittle that changing a single adversarial pixel predictably induces misclassification.[citation needed] Machine learning models are often vulnerable to manipulation and/or evasion via adversarial machine learning.
Researchers have demonstrated how backdoors can be placed undetectably into classifying (e.g., for categories "spam" and well-visible "not spam" of posts) machine learning models which are often developed and/or trained by third parties. Parties can change the classification of any input, including in cases for which a type of data/software transparency is provided, possibly including white-box access.
Machine learning poses a host of ethical questions. Systems that are trained on datasets collected with biases may exhibit these biases upon use (algorithmic bias), thus digitizing cultural prejudices. For example, in 1988, the UK's Commission for Racial Equality found that St. George's Medical School had been using a computer program trained from data of previous admissions staff and this program had denied nearly 60 candidates who were found to be either women or had non-European sounding names. Using job hiring data from a firm with racist hiring policies may lead to a machine learning system duplicating the bias by scoring job applicants by similarity to previous successful applicants. Responsible collection of data and documentation of algorithmic rules used by a system thus is a critical part of machine learning.
AI can be well-equipped to make decisions in technical fields, which rely heavily on data and historical information. These decisions rely on the objectivity and logical reasoning. Because human languages contain biases, machines trained on language corpora will necessarily also learn these biases.
Other forms of ethical challenges, not related to personal biases, are seen in health care. There are concerns among health care professionals that these systems might not be designed in the public's interest but as income-generating machines. This is especially true in the United States where there is a long-standing ethical dilemma of improving health care, but also increase profits. For example, the algorithms could be designed to provide patients with unnecessary tests or medication in which the algorithm's proprietary owners hold stakes. There is potential for machine learning in health care to provide professionals an additional tool to diagnose, medicate, and plan recovery paths for patients, but this requires these biases to be mitigated.
Since the 2010s, advances in both machine learning algorithms and computer hardware have led to more efficient methods for training deep neural networks (a particular narrow subdomain of machine learning) that contain many layers of non-linear hidden units. By 2019, graphic processing units (GPUs), often with AI-specific enhancements, had displaced CPUs as the dominant method of training large-scale commercial cloud AI. OpenAI estimated the hardware computing used in the largest deep learning projects from AlexNet (2012) to AlphaZero (2017), and found a 300,000-fold increase in the amount of compute required, with a doubling-time trendline of 3.4 months.

View File

@@ -0,0 +1 @@
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb

166
evaluation/ceval/ceval.py Normal file
View File

@@ -0,0 +1,166 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
journal={arXiv preprint arXiv:2305.08322},
year={2023}
}
"""
_DESCRIPTION = """\
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
"""
_HOMEPAGE = "https://cevalbenchmark.com"
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
_URL = "ceval.zip"
task_list = [
"computer_network",
"operating_system",
"computer_architecture",
"college_programming",
"college_physics",
"college_chemistry",
"advanced_mathematics",
"probability_and_statistics",
"discrete_mathematics",
"electrical_engineer",
"metrology_engineer",
"high_school_mathematics",
"high_school_physics",
"high_school_chemistry",
"high_school_biology",
"middle_school_mathematics",
"middle_school_biology",
"middle_school_physics",
"middle_school_chemistry",
"veterinary_medicine",
"college_economics",
"business_administration",
"marxism",
"mao_zedong_thought",
"education_science",
"teacher_qualification",
"high_school_politics",
"high_school_geography",
"middle_school_politics",
"middle_school_geography",
"modern_chinese_history",
"ideological_and_moral_cultivation",
"logic",
"law",
"chinese_language_and_literature",
"art_studies",
"professional_tour_guide",
"legal_professional",
"high_school_chinese",
"high_school_history",
"middle_school_history",
"civil_servant",
"sports_science",
"plant_protection",
"basic_medicine",
"clinical_medicine",
"urban_and_rural_planner",
"accountant",
"fire_engineer",
"environmental_impact_assessment_engineer",
"tax_accountant",
"physician",
]
class CevalConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
class Ceval(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CevalConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"id": datasets.Value("int32"),
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
"explanation": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(
data_dir, "test", f"{task_name}_test.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(
data_dir, "val", f"{task_name}_val.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(
data_dir, "dev", f"{task_name}_dev.csv"
),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath, encoding="utf-8")
for i, instance in enumerate(df.to_dict(orient="records")):
if "answer" not in instance.keys():
instance["answer"] = ""
if "explanation" not in instance.keys():
instance["explanation"] = ""
yield i, instance

167
evaluation/cmmlu/cmmlu.py Normal file
View File

@@ -0,0 +1,167 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
journal={arXiv preprint arXiv:2306.09212},
year={2023}
}
"""
_DESCRIPTION = """\
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
"""
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
_URL = "cmmlu.zip"
task_list = [
'agronomy',
'anatomy',
'ancient_chinese',
'arts',
'astronomy',
'business_ethics',
'chinese_civil_service_exam',
'chinese_driving_rule',
'chinese_food_culture',
'chinese_foreign_policy',
'chinese_history',
'chinese_literature',
'chinese_teacher_qualification',
'clinical_knowledge',
'college_actuarial_science',
'college_education',
'college_engineering_hydrology',
'college_law',
'college_mathematics',
'college_medical_statistics',
'college_medicine',
'computer_science',
'computer_security',
'conceptual_physics',
'construction_project_management',
'economics',
'education',
'electrical_engineering',
'elementary_chinese',
'elementary_commonsense',
'elementary_information_and_technology',
'elementary_mathematics',
'ethnology',
'food_science',
'genetics',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_geography',
'high_school_mathematics',
'high_school_physics',
'high_school_politics',
'human_sexuality',
'international_law',
'journalism',
'jurisprudence',
'legal_and_moral_basis',
'logical',
'machine_learning',
'management',
'marketing',
'marxist_theory',
'modern_chinese',
'nutrition',
'philosophy',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_study',
'sociology',
'sports_science',
'traditional_chinese_medicine',
'virology',
'world_history',
'world_religions',
]
class CMMLUConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.1"), **kwargs)
class CMMLU(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CMMLUConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
for i, instance in enumerate(df.to_dict(orient="records")):
question = instance.pop("Question", "")
answer = instance.pop("Answer", "")
instance["question"] = question
instance["answer"] = answer
yield i, instance

167
evaluation/mmlu/mmlu.py Normal file
View File

@@ -0,0 +1,167 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021}
}
"""
_DESCRIPTION = """\
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
"""
_HOMEPAGE = "https://github.com/hendrycks/test"
_LICENSE = "MIT"
_URL = "mmlu.zip"
task_list = [
"high_school_european_history",
"business_ethics",
"clinical_knowledge",
"medical_genetics",
"high_school_us_history",
"high_school_physics",
"high_school_world_history",
"virology",
"high_school_microeconomics",
"econometrics",
"college_computer_science",
"high_school_biology",
"abstract_algebra",
"professional_accounting",
"philosophy",
"professional_medicine",
"nutrition",
"global_facts",
"machine_learning",
"security_studies",
"public_relations",
"professional_psychology",
"prehistory",
"anatomy",
"human_sexuality",
"college_medicine",
"high_school_government_and_politics",
"college_chemistry",
"logical_fallacies",
"high_school_geography",
"elementary_mathematics",
"human_aging",
"college_mathematics",
"high_school_psychology",
"formal_logic",
"high_school_statistics",
"international_law",
"high_school_mathematics",
"high_school_computer_science",
"conceptual_physics",
"miscellaneous",
"high_school_chemistry",
"marketing",
"professional_law",
"management",
"college_physics",
"jurisprudence",
"world_religions",
"sociology",
"us_foreign_policy",
"high_school_macroeconomics",
"computer_security",
"moral_scenarios",
"moral_disputes",
"electrical_engineering",
"astronomy",
"college_biology",
]
class MMLUConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
class MMLU(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
MMLUConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "test", f"{task_name}_test.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "val", f"{task_name}_val.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "dev", f"{task_name}_dev.csv"
),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath)
df.columns = ["question", "A", "B", "C", "D", "answer"]
for i, instance in enumerate(df.to_dict(orient="records")):
yield i, instance

View File

@@ -1,3 +1,46 @@
[build-system] [build-system]
requires = ["setuptools>=61.0"] requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.black]
line-length = 119
target-version = ["py38"]
[tool.ruff]
line-length = 119
indent-width = 4
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["llmtuner"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[isort]
default_section = "FIRSTPARTY"
known_first_party = "llmtuner"
known_third_party = [
"accelerate",
"datasets",
"gradio",
"numpy",
"peft",
"torch",
"transformers",
"trl"
]
line_length = 119
lines_after_imports = 2
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true

View File

@@ -1,16 +1,19 @@
torch>=1.13.1 torch>=1.13.1
transformers>=4.29.1 transformers>=4.37.2
datasets>=2.12.0 datasets>=2.14.3
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.4.0 peft>=0.8.2
trl>=0.4.7 trl>=0.7.6
gradio>=3.38.0,<4.0.0
scipy
einops
sentencepiece sentencepiece
protobuf
jieba jieba
rouge-chinese rouge-chinese
nltk nltk
gradio>=3.36.0
uvicorn uvicorn
pydantic==1.10.11 pydantic
fastapi==0.95.1 fastapi
sse-starlette sse-starlette
matplotlib matplotlib

View File

@@ -25,12 +25,12 @@ def main():
version=get_version(), version=get_version(),
author="hiyouga", author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn", author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use fine-tuning framework using PEFT", description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License", license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Efficient-Tuning", url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"}, package_dir={"": "src"},
packages=find_packages("src"), packages=find_packages("src"),
python_requires=">=3.8.0", python_requires=">=3.8.0",

View File

@@ -1,19 +1,15 @@
# coding=utf-8 import os
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
import uvicorn import uvicorn
from llmtuner import ChatModel from llmtuner import ChatModel, create_app
from llmtuner.api.app import create_app
from llmtuner.tuner import get_infer_args
def main(): def main():
chat_model = ChatModel(*get_infer_args()) chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,14 +1,19 @@
# coding=utf-8
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
from llmtuner import ChatModel from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
def main(): def main():
chat_model = ChatModel(*get_infer_args()) chat_model = ChatModel()
history = [] messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True: while True:
@@ -24,19 +29,20 @@ def main():
break break
if query.strip() == "clear": if query.strip() == "clear":
history = [] messages = []
torch_gc()
print("History has been removed.") print("History has been removed.")
continue continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True) print("Assistant: ", end="", flush=True)
response = "" response = ""
for new_text in chat_model.stream_chat(query, history): for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True) print(new_text, end="", flush=True)
response += new_text response += new_text
print() print()
messages.append({"role": "assistant", "content": response})
history = history + [(query, response)]
if __name__ == "__main__": if __name__ == "__main__":

10
src/evaluate.py Normal file
View File

@@ -0,0 +1,10 @@
from llmtuner import Evaluator
def main():
evaluator = Evaluator()
evaluator.eval()
if __name__ == "__main__":
main()

View File

@@ -1,16 +1,8 @@
# coding=utf-8 from llmtuner import export_model
# Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from llmtuner.tuner import get_train_args, load_model_and_tokenizer
def main(): def main():
model_args, _, training_args, finetuning_args, _ = get_train_args() export_model()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
tokenizer.save_pretrained(training_args.output_dir)
print("model and tokenizer have been saved at:", training_args.output_dir)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,4 +1,11 @@
from llmtuner.chat import ChatModel # Level: api, webui > chat, eval, train > data, model > extras, hparams
from .api import create_app
from .chat import ChatModel
from .eval import Evaluator
from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.1.5" __version__ = "0.5.2"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -0,0 +1,4 @@
from .app import create_app
__all__ = ["create_app"]

View File

@@ -1,36 +1,68 @@
import uvicorn import asyncio
from fastapi import FastAPI, HTTPException import json
from fastapi.middleware.cors import CORSMiddleware import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse from typing import Any, Dict, Sequence
from typing import List, Tuple
from llmtuner.tuner import get_infer_args from pydantic import BaseModel
from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel from ..chat import ChatModel
from llmtuner.api.protocol import ( from ..data import Role as DataRole
Role, from ..extras.misc import torch_gc
Finish, from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
ModelCard, from .protocol import (
ModelList, ChatCompletionMessage,
ChatMessage,
DeltaMessage,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionStreamResponse,
ChatCompletionResponseChoice, ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice, ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
Finish,
Function,
FunctionCall,
ModelCard,
ModelList,
Role,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
) )
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory async def lifespan(app: "FastAPI"): # collects GPU memory
yield yield
torch_gc() torch_gc()
def create_app(chat_model: ChatModel) -> FastAPI: def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware( app.add_middleware(
@@ -41,87 +73,163 @@ def create_app(chat_model: ChatModel) -> FastAPI:
allow_headers=["*"], allow_headers=["*"],
) )
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
role_mapping = {
Role.USER: DataRole.USER,
Role.ASSISTANT: DataRole.ASSISTANT,
Role.SYSTEM: DataRole.SYSTEM,
Role.FUNCTION: DataRole.FUNCTION,
Role.TOOL: DataRole.OBSERVATION,
}
@app.get("/v1/models", response_model=ModelList) @app.get("/v1/models", response_model=ModelList)
async def list_models(): async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo") model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card]) return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest): async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != Role.USER: if not chat_model.can_generate:
raise HTTPException(status_code=400, detail="Invalid request") raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
query = request.messages[-1].content
prev_messages = request.messages[:-1] if len(request.messages) == 0:
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
prefix = prev_messages.pop(0).content
if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
system = request.messages.pop(0).content
else: else:
prefix = None system = ""
history = [] if len(request.messages) % 2 == 0:
if len(prev_messages) % 2 == 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
input_messages = []
for i, message in enumerate(request.messages):
input_messages.append({"role": role_mapping[message.role], "content": message.content})
if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = ""
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
if request.stream: if request.stream:
generate = predict(query, history, prefix, request) if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream") return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat( responses = chat_model.chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
) )
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
response_message = ChatCompletionMessage(
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
)
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage( usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length, prompt_tokens=prompt_length,
completion_tokens=response_length, completion_tokens=response_length,
total_tokens=prompt_length+response_length total_tokens=prompt_length + response_length,
) )
choice_data = ChatCompletionResponseChoice( return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
index=0,
message=ChatMessage(role=Role.ASSISTANT, content=response),
finish_reason=Finish.STOP
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage) def stream_chat_completion(
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest): ):
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
delta=DeltaMessage(role=Role.ASSISTANT),
finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False) yield jsonify(chunk)
for new_text in chat_model.stream_chat( for new_text in chat_model.stream_chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
): ):
if len(new_text) == 0: if len(new_text) == 0:
continue continue
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
delta=DeltaMessage(content=new_text),
finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False) yield jsonify(chunk)
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
delta=DeltaMessage(),
finish_reason=Finish.STOP
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False) yield jsonify(chunk)
yield "[DONE]" yield "[DONE]"
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_score, request)
def get_score(request: ScoreEvaluationRequest):
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)
return app return app
if __name__ == "__main__": if __name__ == "__main__":
chat_model = ChatModel(*get_infer_args()) chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)

View File

@@ -1,33 +1,48 @@
import time import time
from enum import Enum from enum import Enum, unique
from pydantic import BaseModel, Field
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
class Role(str, Enum): class Role(str, Enum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
SYSTEM = "system" SYSTEM = "system"
FUNCTION = "function"
TOOL = "tool"
@unique
class Finish(str, Enum): class Finish(str, Enum):
STOP = "stop" STOP = "stop"
LENGTH = "length" LENGTH = "length"
TOOL = "tool_calls"
class ModelCard(BaseModel): class ModelCard(BaseModel):
id: str id: str
object: Optional[str] = "model" object: Literal["model"] = "model"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Optional[str] = "owner" owned_by: Literal["owner"] = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = []
class ModelList(BaseModel): class ModelList(BaseModel):
object: Optional[str] = "list" object: Literal["list"] = "list"
data: Optional[List[ModelCard]] = [] data: List[ModelCard] = []
class Function(BaseModel):
name: str
arguments: str
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function"
function: Function
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
@@ -35,30 +50,33 @@ class ChatMessage(BaseModel):
content: str content: str
class DeltaMessage(BaseModel): class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None role: Optional[Role] = None
content: Optional[str] = None content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: List[ChatMessage]
tools: Optional[list] = []
do_sample: bool = True
temperature: Optional[float] = None temperature: Optional[float] = None
top_p: Optional[float] = None top_p: Optional[float] = None
n: Optional[int] = 1 n: int = 1
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
stream: Optional[bool] = False stream: bool = False
class ChatCompletionResponseChoice(BaseModel): class ChatCompletionResponseChoice(BaseModel):
index: int index: int
message: ChatMessage message: ChatCompletionMessage
finish_reason: Finish finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel):
index: int index: int
delta: DeltaMessage delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None finish_reason: Optional[Finish] = None
@@ -69,17 +87,30 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default" id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Optional[str] = "chat.completion" object: Literal["chat.completion"] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseChoice] choices: List[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage usage: ChatCompletionResponseUsage
class ChatCompletionStreamResponse(BaseModel): class ChatCompletionStreamResponse(BaseModel):
id: Optional[str] = "chatcmpl-default" id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Optional[str] = "chat.completion.chunk" object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseStreamChoice] choices: List[ChatCompletionResponseStreamChoice]
class ScoreEvaluationRequest(BaseModel):
model: str
messages: List[str]
max_length: Optional[int] = None
class ScoreEvaluationResponse(BaseModel):
id: Literal["scoreeval-default"] = "scoreeval-default"
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]

View File

@@ -1 +1,4 @@
from llmtuner.chat.stream_chat import ChatModel from .chat_model import ChatModel
__all__ = ["ChatModel"]

View File

@@ -0,0 +1,169 @@
from dataclasses import dataclass
from threading import Thread
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..hparams import get_infer_args
from ..model import dispatch_model, load_model_and_tokenizer
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
def _process_args(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> Tuple[Dict[str, Any], int]:
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.to_dict()
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
if max_new_tokens:
generating_args.pop("max_length", None)
generating_args["max_new_tokens"] = max_new_tokens
gen_kwargs = dict(
inputs=input_ids,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List[Response]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(
Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length",
)
)
return results
@torch.inference_mode()
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()
yield from streamer
@torch.inference_mode()
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=True,
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(self.model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
scores = []
for i in range(input_ids.size(0)):
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
scores.append(values[i, end_index].nan_to_num().item())
return scores

View File

@@ -1,102 +0,0 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template
from llmtuner.tuner import load_model_and_tokenizer
if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
class ChatModel:
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments"
) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model)
self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix
self.generating_args = generating_args
def process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix or self.source_prefix
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
inputs = self.tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
))
if max_length:
gen_kwargs.pop("max_new_tokens", None)
gen_kwargs["max_length"] = max_length
if max_new_tokens:
gen_kwargs.pop("max_length", None)
gen_kwargs["max_new_tokens"] = max_new_tokens
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response_length = len(outputs)
return response, (prompt_length, response_length)
@torch.inference_mode()
def stream_chat(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()
yield from streamer

View File

@@ -0,0 +1,6 @@
from .loader import get_dataset
from .template import get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]

View File

@@ -0,0 +1,131 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from .utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER, "content": "\n".join(content)})
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
else:
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER,
dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION,
dataset_attr.system_tag: Role.SYSTEM,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)

View File

@@ -0,0 +1,154 @@
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool{format_prompt}.\n"
"```\n"
)
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
items = (
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
)
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
items=items,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
action_match = re.search(regex, content)
if not action_match:
return content
tool_name = action_match.group(1).strip()
tool_input = action_match.group(2).strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return tool_name, json.dumps(arguments, ensure_ascii=False)
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
raise NotImplementedError
@dataclass
class EmptyFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value))
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class FunctionFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
function = json.loads(content)
name = function["name"]
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except Exception:
name, arguments = "", ""
elements = []
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@dataclass
class ToolFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
if not len(tools):
return [""]
if self.tool_format == "default":
return [default_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
return [""]
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
if self.tool_format == "default":
return default_tool_extractor(content)
else:
raise NotImplementedError

191
src/llmtuner/data/loader.py Normal file
View File

@@ -0,0 +1,191 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
logger = get_logger(__name__)
def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments",
):
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.file_sha1)
else:
raise NotImplementedError
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
**kwargs,
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
return align_dataset(dataset, dataset_attr, data_args)
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset",
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset

119
src/llmtuner/data/parser.py Normal file
View File

@@ -0,0 +1,119 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
""" basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
""" extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
""" columns """
system: Optional[str] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
""" columns for the sharegpt format """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system"
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
tag_names = (
"role_tag",
"content_tag",
"user_tag",
"assistant_tag",
"observation_tag",
"function_tag",
"system_tag",
)
for tag in tag_names:
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_list.append(dataset_attr)
return dataset_list

View File

@@ -0,0 +1,269 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from .utils import Role
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
from .template import Template
logger = get_logger(__name__)
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
continue
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer,
chosen_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(
"labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
)
)
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.sft_packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function

View File

@@ -0,0 +1,626 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from .formatter import Formatter
logger = get_logger(__name__)
@dataclass
class Template:
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids(
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts elements to token ids.
"""
token_ids = []
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
def _make_pairs(
self,
encoded_messages: Sequence[List[int]],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]),
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
source_ids = encoded_messages[i][:max_source_len]
target_ids = encoded_messages[i + 1][:max_target_len]
total_length += len(source_ids) + len(target_ids)
encoded_pairs.append((source_ids, target_ids))
return encoded_pairs
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
def _register_template(
name: str,
format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None,
format_system: Optional["Formatter"] = None,
format_function: Optional["Formatter"] = None,
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
default_system: Optional[str] = "",
stop_words: Optional[List[str]] = [],
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False,
force_system: Optional[bool] = False,
) -> None:
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
)
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
is_added = tokenizer.eos_token_id is None
is_oov = eos_token not in tokenizer.get_vocab()
tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token))
else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if is_oov:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
else:
template = templates.get(name, None)
if templates is None:
raise ValueError("Template {} does not exist.".format(name))
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if stop_words:
tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
return template
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
),
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_separator=EmptyFormatter(slots=["###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
efficient_eos=True,
)
_register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
_register_template(
name="baichuan2",
format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
efficient_eos=True,
)
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
force_system=True,
)
_register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
force_system=True,
)
_register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
stop_words=["<|EOT|>"],
efficient_eos=True,
)
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
stop_words=["<eoa>"],
efficient_eos=True,
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed "
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
"by the user such as English and 中文."
),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
)
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
efficient_eos=True,
)
_register_template(
name="starchat",
format_user=StringFormatter(
slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
force_system=True,
)
_register_template(
name="vanilla",
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头人工智能助手以Assistant开头"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
),
)
_register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
stop_words=["<|End|>"],
)
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<eod>"],
replace_eos=True,
)
_register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
)
_register_template(
name="ziya",
format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
format_separator=EmptyFormatter(slots=["\n"]),
)

View File

@@ -0,0 +1,68 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
OBSERVATION = "observation"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - max_target_len
return max_source_len, max_target_len
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -1,3 +0,0 @@
from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

View File

@@ -1,116 +0,0 @@
import os
import hashlib
from typing import TYPE_CHECKING, List, Optional
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
) -> "Dataset":
max_samples = data_args.max_samples
all_datasets: List["Dataset"] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
data_path = dataset_attr.dataset_name
data_files = None
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_files = None
elif dataset_attr.load_from == "file":
data_path = None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
dataset = load_dataset(
data_path,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
)
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
for column_name in ["prompt", "query", "response", "history"]: # align datasets
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.source_prefix: # add prefix
features = None
if data_args.streaming:
features = dataset.features
features["prefix"] = Value(dtype="string", id=None)
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
else:
raise ValueError("Unknown mixing strategy.")

View File

@@ -1,180 +0,0 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
from itertools import chain
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template
if TYPE_CHECKING:
from datasets import Dataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
def preprocess_dataset(
dataset: "Dataset",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> "Dataset":
column_names = list(dataset.column_names)
template = get_template(data_args.template)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
prefix = examples["prefix"][i] if "prefix" in examples else None
yield query, response, history, prefix
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.max_source_length
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of max_source_length
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], []
for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
break
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `<bos> Y`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, prefix in construct_example(examples):
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
target_ids = tokenizer.encode(text=response, add_special_tokens=True)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
target_ids = target_ids[:data_args.max_target_length]
model_inputs["input_ids"].append(source_ids)
model_inputs["attention_mask"].append([1] * len(source_ids))
model_inputs["labels"].append(target_ids)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for query, response, history, prefix in construct_example(examples):
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
return model_inputs
def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
skip_special_tokens=False)
))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_function,
batched=True,
remove_columns=column_names,
**kwargs
)
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
print_function(next(iter(dataset)))
return dataset

View File

@@ -1,15 +0,0 @@
from typing import TYPE_CHECKING, Dict
if TYPE_CHECKING:
from datasets import Dataset
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
if do_train:
if dev_ratio > 1e-6: # Split the dataset
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -0,0 +1,4 @@
from .evaluator import Evaluator
__all__ = ["Evaluator"]

View File

@@ -0,0 +1,123 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import inspect
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm, trange
from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer
from .template import get_eval_template
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None:
mapping = cached_file(
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token,
)
with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token,
**kwargs,
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = (
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
)
messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
)
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(messages[-1]["content"])
for i in trange(
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
):
batch_input = self.tokenizer.pad(
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
).to(self.model.device)
preds = self.batch_inference(batch_input)
outputs += preds
corrects = np.array(outputs) == np.array(labels)
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join(
[
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items()
if len(category_correct)
]
)
print(score_info)
if self.eval_args.save_dir is not None:
os.makedirs(self.eval_args.save_dir, exist_ok=False)
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
json.dump(results, f, indent=2)
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()

View File

@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
) -> List[Dict[str, str]]:
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template
register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" ",
)
register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n",
)

View File

@@ -1,71 +1,153 @@
import os
import json import json
import os
import time import time
from typing import TYPE_CHECKING
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING
from transformers import TrainerCallback from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
def __init__(self, runner=None): def __init__(self, runner=None):
self.runner = runner self.runner = runner
self.in_training = False
self.start_time = time.time() self.start_time = time.time()
self.tracker = {} self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
def timing(self):
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the beginning of training. Event called at the beginning of training.
""" """
self.start_time = time.time() if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step Event called at the end of training.
might take several inputs.
""" """
if self.runner is not None and self.runner.aborted: if state.is_local_process_zero:
control.should_epoch_stop = True self.in_training = False
control.should_training_stop = True self.cur_steps = 0
self.max_steps = 0
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of an substep during gradient accumulation. Event called at the end of an substep during gradient accumulation.
""" """
if self.runner is not None and self.runner.aborted: if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
r"""
Event called after a successful prediction.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r""" r"""
Event called after logging the last logs. Event called after logging the last logs.
""" """
if not state.is_world_process_zero: if not state.is_local_process_zero:
return return
cur_time = time.time() logs = dict(
cur_steps = state.log_history[-1].get("step") current_steps=self.cur_steps,
elapsed_time = cur_time - self.start_time total_steps=self.max_steps,
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 loss=state.log_history[-1].get("loss", None),
remaining_steps = state.max_steps - cur_steps eval_loss=state.log_history[-1].get("eval_loss", None),
remaining_time = remaining_steps * avg_time_per_step predict_loss=state.log_history[-1].get("predict_loss", None),
self.tracker = { reward=state.log_history[-1].get("reward", None),
"current_steps": cur_steps, learning_rate=state.log_history[-1].get("learning_rate", None),
"total_steps": state.max_steps, epoch=state.log_history[-1].get("epoch", None),
"loss": state.log_history[-1].get("loss", None), percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
"eval_loss": state.log_history[-1].get("eval_loss", None), elapsed_time=self.elapsed_time,
"predict_loss": state.log_history[-1].get("predict_loss", None), remaining_time=self.remaining_time,
"reward": state.log_history[-1].get("reward", None), )
"learning_rate": state.log_history[-1].get("learning_rate", None), if self.runner is not None:
"epoch": state.log_history[-1].get("epoch", None), logger.info(
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
"elapsed_time": str(timedelta(seconds=int(elapsed_time))), logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
"remaining_time": str(timedelta(seconds=int(remaining_time))) )
} )
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n") f.write(json.dumps(logs) + "\n")
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()

View File

@@ -1,47 +1,866 @@
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DATA_CONFIG = "dataset_info.json"
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = {
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text",
}
IGNORE_INDEX = -100 IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin" LAYERNORM_NAMES = {"norm", "ln"}
FINETUNING_ARGS_NAME = "finetuning_args.json" LOG_FILE_NAME = "trainer_log.jsonl"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
SUPPORTED_MODELS = { PEFT_METHODS = ["lora"]
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b", SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b", SUPPORTED_MODELS = OrderedDict()
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf", TRAINING_STAGES = {
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf", "Supervised Fine-Tuning": "sft",
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf", "Reward Modeling": "rm",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf", "PPO": "ppo",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf", "DPO": "dpo",
"BLOOM-560M": "bigscience/bloom-560m", "Pre-Training": "pt",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
} }
DEFAULT_MODULE = { V_HEAD_WEIGHTS_NAME = "value_head.bin"
"LLaMA": "q_proj,v_proj",
"LLaMA2": "q_proj,v_proj", V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value", class DownloadSource(str, Enum):
"Baichuan": "W_pack", DEFAULT = "hf"
"InternLM": "q_proj,v_proj" MODELSCOPE = "ms"
}
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
module: Optional[str] = None,
template: Optional[str] = None,
) -> None:
prefix = None
for name, path in models.items():
if prefix is None:
prefix = name.split("-")[0]
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path
if module is not None:
DEFAULT_MODULE[prefix] = module
if template is not None:
DEFAULT_TEMPLATE[prefix] = template
register_model_group(
models={
"Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
},
"Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
},
"Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
module="W_pack",
template="baichuan",
)
register_model_group(
models={
"Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
},
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
},
},
module="W_pack",
template="baichuan2",
)
register_model_group(
models={
"BLOOM-560M": {
DownloadSource.DEFAULT: "bigscience/bloom-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
},
"BLOOM-3B": {
DownloadSource.DEFAULT: "bigscience/bloom-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
},
"BLOOM-7B1": {
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
},
},
module="query_key_value",
)
register_model_group(
models={
"BLOOMZ-560M": {
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
},
"BLOOMZ-3B": {
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
},
"BLOOMZ-7B1-mt": {
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
},
},
module="query_key_value",
)
register_model_group(
models={
"BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
},
"BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
},
},
template="bluelm",
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
module="query_key_value",
template="chatglm2",
)
register_model_group(
models={
"ChatGLM3-6B-Base": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
},
"ChatGLM3-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
module="query_key_value",
template="chatglm3",
)
register_model_group(
models={
"ChineseLLaMA2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
},
"ChineseLLaMA2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
},
"ChineseLLaMA2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
},
"ChineseLLaMA2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
},
"ChineseLLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
},
"ChineseLLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
},
},
template="llama2_zh",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
},
"DeepSeek-LLM-67B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
},
"DeepSeek-LLM-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
},
"DeepSeek-LLM-67B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
},
template="deepseek",
)
register_model_group(
models={
"DeepSeekCoder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
},
"DeepSeekCoder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
},
"DeepSeekCoder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
},
"DeepSeekCoder-6.7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
},
"DeepSeekCoder-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
},
"DeepSeekCoder-33B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
},
},
template="deepseekcoder",
)
register_model_group(
models={
"Falcon-7B": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
},
"Falcon-180B": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
},
"Falcon-7B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
},
"Falcon-40B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
},
"Falcon-180B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
},
},
module="query_key_value",
template="falcon",
)
register_model_group(
models={
"InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
},
"InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
},
"InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
},
"InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
},
},
template="intern",
)
register_model_group(
models={
"InternLM2-7B": {
DownloadSource.DEFAULT: "internlm/internlm2-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
},
"InternLM2-20B": {
DownloadSource.DEFAULT: "internlm/internlm2-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
},
"InternLM2-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
},
"InternLM2-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
},
},
module="wqkv",
template="intern2",
)
register_model_group(
models={
"LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
module="qkv_proj",
)
register_model_group(
models={
"LLaMA-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
},
"LLaMA-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
},
"LLaMA-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
},
"LLaMA-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
},
}
)
register_model_group(
models={
"LLaMA2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
},
"LLaMA2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
},
"LLaMA2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
},
"LLaMA2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
},
"LLaMA2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
},
"LLaMA2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
},
},
template="llama2",
)
register_model_group(
models={
"Mistral-7B": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
},
template="mistral",
)
register_model_group(
models={
"Mixtral-8x7B": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
},
template="mistral",
)
register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
}
},
template="openchat",
)
register_model_group(
models={
"Orion-14B-Base": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
},
"Orion-14B-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
},
"Orion-14B-Long-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
},
"Orion-14B-RAG-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
},
"Orion-14B-Plugin-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
},
},
template="orion",
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
},
}
)
register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
},
},
module="c_attn",
template="qwen",
)
register_model_group(
models={
"Qwen1.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
},
"Qwen1.5-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
},
"Qwen1.5-4B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
},
"Qwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
},
"Qwen1.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
},
"Qwen1.5-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
},
"Qwen1.5-4B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
},
"Qwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
},
"Qwen1.5-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
},
"Qwen1.5-1.8B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
},
"Qwen1.5-4B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
},
"Qwen1.5-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
},
"Qwen1.5-14B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
},
"Qwen1.5-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
},
},
template="qwen",
)
register_model_group(
models={
"SOLAR-10.7B": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Chat": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group(
models={
"Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
}
}
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
},
"Vicuna1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
},
},
template="vicuna",
)
register_model_group(
models={
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
},
template="xuanyuan",
)
register_model_group(
models={
"XVERSE-7B": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
},
"XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
},
"XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
},
"XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
},
"XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
},
template="xverse",
)
register_model_group(
models={
"Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
},
"Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
},
},
template="yayi",
)
register_model_group(
models={
"Yi-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-6B",
},
"Yi-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-34B",
},
"Yi-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
},
"Yi-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
},
"Yi-6B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
},
"Yi-34B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
},
},
template="yi",
)
register_model_group(
models={
"Yuan2-2B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
},
"Yuan2-51B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
},
"Yuan2-102B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
},
},
template="yuan",
)
register_model_group(
models={
"Zephyr-7B-Alpha-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
},
"Zephyr-7B-Beta-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
},
template="zephyr",
)

View File

@@ -1,13 +1,19 @@
import sys
import logging import logging
import sys
class LoggerHandler(logging.Handler): class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.log = "" self.log = ""
def reset(self):
self.log = ""
def emit(self, record): def emit(self, record):
if record.name == "httpx": if record.name == "httpx":
return return
@@ -16,19 +22,12 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n" self.log += "\n\n"
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:
r"""
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
datefmt="%m/%d/%Y %H:%M:%S"
) )
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter) handler.setFormatter(formatter)
@@ -38,3 +37,12 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler) logger.addHandler(handler)
return logger return logger
def reset_logging() -> None:
r"""
Removes basic config of root logger. (unused in script)
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))

View File

@@ -1,19 +1,45 @@
import gc
import os
from typing import TYPE_CHECKING, Dict, Tuple
import torch import torch
from typing import TYPE_CHECKING, List, Optional, Tuple from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.generation.utils import LogitsProcessorList from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from transformers.generation.logits_process import LogitsProcessor from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
except Exception:
_is_bf16_available = False
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
class AverageMeter: class AverageMeter:
r""" r"""
Computes and stores the average and current value. Computes and stores the average and current value.
""" """
def __init__(self): def __init__(self):
self.reset() self.reset()
@@ -30,22 +56,6 @@ class AverageMeter:
self.avg = self.sum / self.count self.avg = self.sum / self.count
# Avoids runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r""" r"""
Returns the number of trainable parameters and number of all parameters in the model. Returns the number of trainable parameters and number of all parameters in the model.
@@ -68,74 +78,121 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param return trainable_params, all_param
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 def fix_valuehead_checkpoint(
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
def prepare_model_for_training( ) -> None:
model: "PreTrainedModel", r"""
finetuning_type: str, The model is already unwrapped.
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
for name, param in model.named_parameters(): There are three cases:
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
param.data = param.data.to(torch.float32) 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
if use_gradient_checkpointing: We assume `stage3_gather_16bit_weights_on_model_save=true`.
if hasattr(model, "enable_input_require_grads"): """
model.enable_input_require_grads() if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
return
if safe_serialization:
from safetensors import safe_open
from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
decoder_state_dict = {}
v_head_state_dict = {}
for name, param in state_dict.items():
if name.startswith("v_head."):
v_head_state_dict[name] = param
else: else:
def make_inputs_require_grad(module, input, output): decoder_state_dict[name.replace("pretrained_model.", "")] = param
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable() os.remove(path_to_checkpoint)
model.config.use_cache = False # turn off when gradient checkpointing is enabled model.pretrained_model.save_pretrained(
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if finetuning_type != "full" and hasattr(model, output_layer_name): if safe_serialization:
if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"): save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728) else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
output_layer: torch.nn.Linear = getattr(model, output_layer_name) logger.info("Value head model saved at: {}".format(output_dir))
input_dtype = output_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor: def get_current_device() -> torch.device:
return super().forward(x.to(input_dtype)).to(torch.float32) r"""
Gets the current available device.
"""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
device = "cpu"
setattr(model, output_layer_name, CastOutputToFloat(output_layer)) return torch.device(device)
return model
def get_device_count() -> int:
return torch.cuda.device_count()
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32
def torch_gc() -> None: def torch_gc() -> None:
r""" r"""
Collects GPU memory. Collects GPU memory.
""" """
gc.collect()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": def try_download_model_from_ms(model_args: "ModelArguments") -> None:
r""" if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
Dispatches a pre-trained model to GPUs with balanced memory. return
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None: try:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") from modelscope import snapshot_download
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} revision = "master" if model_args.model_revision == "main" else model_args.model_revision
max_memory = get_balanced_memory(model, **kwargs) model_args.model_name_or_path = snapshot_download(
# Make sure tied weights are tied before creating the device map. model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
model.tie_weights() )
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) except ImportError:
return dispatch_model(model, device_map) raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
return model.cuda()
def use_modelscope() -> bool:
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))

View File

@@ -0,0 +1,53 @@
import importlib.metadata
import importlib.util
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def _get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except Exception:
return "0.0.0"
def is_fastapi_availble():
return _is_package_available("fastapi")
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
def is_jieba_available():
return _is_package_available("jieba")
def is_matplotlib_available():
return _is_package_available("matplotlib")
def is_nltk_available():
return _is_package_available("nltk")
def is_requests_available():
return _is_package_available("requests")
def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_starlette_available():
return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available():
return _is_package_available("uvicorn")

View File

View File

@@ -0,0 +1,197 @@
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_torch_attn_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
dropout_rate = self.attention_dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def apply_llama_patch() -> None:
LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward

View File

@@ -0,0 +1,38 @@
import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
def patch_mixtral_replace_moe_impl() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -1,11 +1,16 @@
import os
import math
import json import json
import matplotlib.pyplot as plt import math
import os
from typing import List, Optional from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger from .logging import get_logger
from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -17,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
""" """
last = scalars[0] last = scalars[0]
smoothed = list() smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars: for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val) smoothed.append(smoothed_val)
@@ -26,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)

View File

@@ -1,49 +0,0 @@
import os
import torch
from typing import Dict
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if os.path.exists(weights_file):
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
else:
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
return False
return True
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
if not os.path.exists(valuehead_file):
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
return True

View File

@@ -1,262 +0,0 @@
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class Template:
prefix: str
prompt: str
sep: str
use_history: bool
def get_prompt(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = "",
eos_token: Optional[str] = "</s>"
) -> str:
r"""
Returns a string containing prompt without response.
"""
return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
def get_dialog(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
r"""
Returns a list containing prompt-response pairs.
"""
result = self._format_example(query, history, prefix)
result[-1][-1] = resp
return result
def _format_example(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, "")]
return [
[(self.sep if i else prefix) + self.prompt.format(query=q), r]
for i, (q, r) in enumerate(history)
]
@dataclass
class Llama2Template(Template):
def _format_example(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
history = history if (history and self.use_history) else []
history = history + [(query, "")]
return [
[(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
for i, (q, r) in enumerate(history)
]
templates: Dict[str, Template] = {}
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
template_class = Llama2Template if name == "llama2" else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history
)
def get_template(name: str) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
return template
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
use_history=False
)
r"""
Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt="[INST] {query} [/INST] ",
sep="<s>",
use_history=True
)
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="",
use_history=True
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
)

View File

@@ -1,5 +1,18 @@
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments from .finetuning_args import FinetuningArguments
from .general_args import GeneralArguments
from .generating_args import GeneratingArguments from .generating_args import GeneratingArguments
from .model_args import ModelArguments from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args
__all__ = [
"DataArguments",
"EvaluationArguments",
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"get_eval_args",
"get_infer_args",
"get_train_args",
]

View File

@@ -1,130 +1,98 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional
@dataclass
class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
source_prefix: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
def __post_init__(self):
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass @dataclass
class DataArguments: class DataArguments:
""" r"""
Arguments pertaining to what data we are going to input our model for training and evaluation. Arguments pertaining to what data we are going to input our model for training and evaluation.
""" """
template: str = field(
metadata={"help": "Which template to use for constructing prompts in training and inference."} template: Optional[str] = field(
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."},
) )
dataset: Optional[str] = field( dataset: Optional[str] = field(
default="alpaca_en", default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
) )
dataset_dir: Optional[str] = field( dataset_dir: Optional[str] = field(
default="data", default="data",
metadata={"help": "The name of the folder containing datasets."} metadata={"help": "Path to the folder containing the datasets."},
) )
split: Optional[str] = field( split: Optional[str] = field(
default="train", default="train",
metadata={"help": "Which dataset split to use for training and evaluation."} metadata={"help": "Which dataset split to use for training and evaluation."},
)
cutoff_len: Optional[int] = field(
default=1024,
metadata={"help": "The cutoff length of the model inputs after tokenization."},
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."},
) )
streaming: Optional[bool] = field( streaming: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Enable streaming mode."} metadata={"help": "Enable dataset streaming."},
) )
buffer_size: Optional[int] = field( buffer_size: Optional[int] = field(
default=16384, default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."} metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
) )
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat", default="concat",
metadata={"help": "Strategy to use in dataset mixing."} metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
)
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
) )
overwrite_cache: Optional[bool] = field( overwrite_cache: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."} metadata={"help": "Overwrite the cached training and evaluation sets."},
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of processes to use for the preprocessing."} metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total input sequence length after tokenization."}
)
max_target_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total output sequence length after tokenization."}
) )
max_samples: Optional[int] = field( max_samples: Optional[int] = field(
default=None, default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
) )
eval_num_beams: Optional[int] = field( eval_num_beams: Optional[int] = field(
default=None, default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
) )
ignore_pad_token_for_loss: Optional[bool] = field( ignore_pad_token_for_loss: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} metadata={
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
},
) )
source_prefix: Optional[str] = field( val_size: Optional[float] = field(
default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
)
dev_ratio: Optional[float] = field(
default=0, default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
)
sft_packing: Optional[bool] = field(
default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
)
cache_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the preprocessed datasets."},
) )
def init_for_training(self): # support mixing multiple datasets def __post_init__(self):
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.reserved_label_len >= self.cutoff_len:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
dataset_info = json.load(f)
if self.source_prefix is not None: if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
prefix_list = self.source_prefix.split("|") raise ValueError("Streaming mode should have an integer val size.")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
else:
prefix_list = [None] * len(dataset_names)
self.dataset_list: List[DatasetAttr] = [] if self.streaming and self.max_samples is not None:
for i, name in enumerate(dataset_names): raise ValueError("`max_samples` is incompatible with `streaming`.")
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)

View File

@@ -0,0 +1,48 @@
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field(
metadata={"help": "Name of the evaluation task."},
)
task_dir: Optional[str] = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."},
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."},
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."},
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."},
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."},
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."},
)
download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."},
)
def __post_init__(self):
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")

View File

@@ -1,79 +1,218 @@
import json import json
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Literal, Optional
@dataclass @dataclass
class FinetuningArguments: class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
""" """
Arguments pertaining to which techniques we are going to fine-tuning with.
""" name_module_trainable: Optional[str] = field(
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field( default=None,
default="lora", metadata={
metadata={"help": "Which fine-tuning method to use."} "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
) Use commas to separate multiple modules. \
num_hidden_layers: Optional[int] = field( Use "all" to specify all the available modules. \
default=32, LLaMA choices: ["mlp", "self_attn"], \
metadata={"help": "Number of decoder blocks in the model. \ BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \ Qwen choices: ["mlp", "attn"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \ InternLM2 choices: ["feed_forward", "attention"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \ Others choices: the same as LLaMA."""
Falcon choices: [\"32\", \"60\"], \ },
Baichuan choices: [\"32\", \"40\"]"}
) )
num_layer_trainable: Optional[int] = field( num_layer_trainable: Optional[int] = field(
default=3, default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."} metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
) )
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( use_llama_pro: Optional[bool] = field(
default="mlp", default=False,
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \ metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \ )
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
additional_target: Optional[str] = field(
default=None,
metadata={
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
},
)
lora_alpha: Optional[int] = field(
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
)
lora_dropout: Optional[float] = field(
default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
) )
lora_rank: Optional[int] = field( lora_rank: Optional[int] = field(
default=8, default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
)
lora_dropout: Optional[float] = field(
default=0.1,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
) )
lora_target: Optional[str] = field( lora_target: Optional[str] = field(
default="q_proj,v_proj", default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ metadata={
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ "help": """Name(s) of target modules to apply LoRA. \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ Use commas to separate multiple modules. \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"} Use "all" to specify all the available modules. \
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
Others choices: the same as LLaMA."""
},
)
lora_bf16_mode: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
)
use_rslora: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
)
create_new_adapter: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
)
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
"""
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."},
)
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_ftx: Optional[float] = field(
default=0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
ppo_buffer_size: Optional[int] = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
)
ppo_epochs: Optional[int] = field(
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."},
)
ppo_target: Optional[float] = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
)
ppo_whiten_rewards: Optional[bool] = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
)
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
)
ref_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reference model."},
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."},
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reward model used for the PPO training."},
)
reward_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reward model."},
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."},
)
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
disable_version_checking: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable version checking."},
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
) )
def __post_init__(self): def __post_init__(self):
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA def split_arg(arg):
self.lora_target = [target.strip() for target in self.lora_target.split(",")] if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 self.name_module_trainable = split_arg(self.name_module_trainable)
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)] self.lora_alpha = self.lora_alpha or self.lora_rank * 2
else: # fine-tuning the first n layers if num_layer_trainable < 0 self.lora_target = split_arg(self.lora_target)
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] self.additional_target = split_arg(self.additional_target)
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids] assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
if self.use_llama_pro and self.finetuning_type != "freeze":
raise ValueError("`use_llama_pro` is only valid for the Freeze method.")
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
"""Saves the content of this instance in JSON format inside `json_path`.""" r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f: with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string) f.write(json_string)
@classmethod @classmethod
def load_from_json(cls, json_path: str): def load_from_json(cls, json_path: str):
"""Creates an instance from the content of `json_path`.""" r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f: with open(json_path, "r", encoding="utf-8") as f:
text = f.read() text = f.read()
return cls(**json.loads(text)) return cls(**json.loads(text))

View File

@@ -1,13 +0,0 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
"""
Arguments pertaining to which stage we are going to perform.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)

View File

@@ -1,51 +1,56 @@
from typing import Any, Dict, Optional
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
@dataclass @dataclass
class GeneratingArguments: class GeneratingArguments:
""" r"""
Arguments pertaining to specify the decoding parameters. Arguments pertaining to specify the decoding parameters.
""" """
do_sample: Optional[bool] = field( do_sample: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
) )
temperature: Optional[float] = field( temperature: Optional[float] = field(
default=0.95, default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."} metadata={"help": "The value used to modulate the next token probabilities."},
) )
top_p: Optional[float] = field( top_p: Optional[float] = field(
default=0.7, default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
},
) )
top_k: Optional[int] = field( top_k: Optional[int] = field(
default=50, default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
) )
num_beams: Optional[int] = field( num_beams: Optional[int] = field(
default=1, default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."} metadata={"help": "Number of beams for beam search. 1 means no beam search."},
) )
max_length: Optional[int] = field( max_length: Optional[int] = field(
default=None, default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
) )
max_new_tokens: Optional[int] = field( max_new_tokens: Optional[int] = field(
default=512, default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
) )
repetition_penalty: Optional[float] = field( repetition_penalty: Optional[float] = field(
default=1.0, default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
) )
length_penalty: Optional[float] = field( length_penalty: Optional[float] = field(
default=1.0, default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
) )
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
args = asdict(self) args = asdict(self)
if args.get("max_new_tokens", None): if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None) args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
return args return args

View File

@@ -1,72 +1,142 @@
import torch from dataclasses import asdict, dataclass, field
from typing import Literal, Optional from typing import Any, Dict, Literal, Optional
from dataclasses import dataclass, field
@dataclass @dataclass
class ModelArguments: class ModelArguments:
""" r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune. Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
""" """
model_name_or_path: str = field( model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
) )
cache_dir: Optional[str] = field( cache_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
) )
use_fast_tokenizer: Optional[bool] = field( use_fast_tokenizer: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
) )
use_auth_token: Optional[bool] = field( resize_vocab: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."} metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
)
split_special_tokens: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
) )
model_revision: Optional[str] = field( model_revision: Optional[str] = field(
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
padding_side: Optional[Literal["left", "right"]] = field(
default="left",
metadata={"help": "The side on which the model should have padding applied."}
) )
quantization_bit: Optional[int] = field( quantization_bit: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the model."} metadata={"help": "The number of bits to quantize the model."},
) )
quantization_type: Optional[Literal["fp4", "nf4"]] = field( quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4", default="nf4",
metadata={"help": "Quantization data type to use in int4 training."} metadata={"help": "Quantization data type to use in int4 training."},
) )
double_quantization: Optional[bool] = field( double_quantization: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."} metadata={"help": "Whether or not to use double quantization in int4 training."},
) )
compute_dtype: Optional[torch.dtype] = field( rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None, default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."} metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
) )
checkpoint_dir: Optional[str] = field( flash_attn: Optional[bool] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."} metadata={"help": "Enable FlashAttention-2 for faster training."},
)
shift_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
use_unsloth: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
disable_gradient_checkpointing: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: Optional[int] = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
print_param_status: Optional[bool] = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
) )
def __post_init__(self): def __post_init__(self):
if self.checkpoint_dir is not None: # support merging multiple lora weights self.compute_dtype = None
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] self.model_max_length = None
if self.quantization_bit is not None: if self.split_special_tokens and self.use_fast_tokenizer:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

View File

@@ -0,0 +1,271 @@
import logging
import os
import sys
from typing import Any, Dict, Optional, Tuple
import datasets
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
logger = get_logger(__name__)
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _check_dependencies(disabled: bool) -> None:
if disabled:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if unknown_args:
print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
return (*parsed_args,)
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if (
training_args.do_train
and finetuning_args.finetuning_type == "freeze"
and finetuning_args.name_module_trainable is None
):
raise ValueError("Please specify `name_module_trainable` in Freeze training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info(
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
training_args.resume_from_checkpoint
)
)
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
)
# Post-process model arguments
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info(
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
str(model_args.compute_dtype),
)
)
logger.info(f"Training/evaluation parameters {training_args}")
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

View File

@@ -0,0 +1,5 @@
from .loader import load_model_and_tokenizer
from .utils import dispatch_model, load_valuehead_params
__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]

View File

@@ -0,0 +1,149 @@
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras.logging import get_logger
from .utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if (not is_trainable) and model_args.adapter_name_or_path is None:
logger.info("Adapter is not found at evaluation, load the base model.")
return model
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.num_layer_trainable
)
)
stride = num_layers // finetuning_args.num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
freeze_modules = {"all"}
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
if module_name not in freeze_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
else:
adapter_to_merge = model_args.adapter_name_or_path
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model)
else:
target_modules = finetuning_args.lora_target
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
}
if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
return model

View File

@@ -0,0 +1,132 @@
from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
add_valuehead: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
Support both training and inference.
"""
try_download_model_from_ms(model_args)
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**config_kwargs,
)
patch_tokenizer(tokenizer)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
model = None
if is_trainable and model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = {
"model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
}
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
if model_args.adapter_name_or_path:
model_args.adapter_name_or_path = None
logger.warning("Unsloth does not support loading adapters.")
if model is None:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs,
)
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False)
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
logger.info(
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
)
if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
if model_args.print_param_status:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
return model, tokenizer

View File

@@ -0,0 +1,326 @@
import math
import os
import random
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
if model_args.flash_attn:
if is_flash_attn2_available():
config_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention2 is not installed.")
config_kwargs["attn_implementation"] = None
else:
config_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_longlora(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quantization_config["use_exllama"] = False # disable exllama
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
config_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
logger.info("Upcasting layernorm weights in float32.")
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
_configure_attn_implementation(model_args, config_kwargs)
if model_args.rope_scaling is not None:
_configure_rope(config, model_args, is_trainable)
if is_trainable and model_args.shift_attn:
_configure_longlora(config)
_configure_quantization(config, tokenizer, model_args, config_kwargs)
def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
) -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer)
if is_trainable:
_prepare_model_for_training(model, model_args)
if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if is_trainable:
patch_mixtral_replace_moe_impl()
try:
model.add_model_tags(["llama-factory"])
except Exception:
logger.warning("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
self.pretrained_model.tie_weights()
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))

113
src/llmtuner/model/utils.py Normal file
View File

@@ -0,0 +1,113 @@
import inspect
from typing import TYPE_CHECKING, Dict, List
import torch
from transformers import PreTrainedModel
from transformers.utils import cached_file
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ..hparams import ModelArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
"""
if getattr(model, "quantization_method", None): # already set on current device
return model
if (
torch.cuda.device_count() > 1
and isinstance(model, PreTrainedModel)
and model._no_split_modules is not None
and model.config.model_type != "chatglm"
):
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
return dispatch_model(model, **device_map_kwargs)
else:
return model.to(device=get_current_device())
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r"""
Finds all available modules to apply lora.
"""
quantization_method = getattr(model, "quantization_method", None)
if quantization_method is None:
linear_cls = torch.nn.Linear
elif quantization_method == "bitsandbytes":
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
module_names = set()
for name, module in model.named_modules():
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

View File

@@ -0,0 +1,4 @@
from .tuner import export_model, run_exp
__all__ = ["export_model", "run_exp"]

View File

@@ -0,0 +1,4 @@
from .workflow import run_dpo
__all__ = ["run_dpo"]

View File

@@ -0,0 +1,54 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
start, end = feature.size(0) - answer_len, feature.size(0)
else:
start, end = prompt_len, prompt_len + answer_len
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
label_positions = []
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append(
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len),
}
)
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(
concatenated_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch

View File

@@ -0,0 +1,148 @@
from collections import defaultdict
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedModel
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
**kwargs,
):
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.is_encoder_decoder = model.config.is_encoder_decoder
self.precompute_ref_log_probs = False
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self.beta = beta
self.label_smoothing = 0
self.loss_type = loss_type
self.ftx_gamma = ftx_gamma
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model(
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
).logits.to(torch.float32)
all_logps = self.get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False,
label_pad_token_id=self.label_pad_token_id,
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor],
train_eval: Optional[Literal["train", "eval"]] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
if self.ftx_gamma > 1e-6:
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
return losses.mean(), metrics

View File

@@ -0,0 +1,84 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model_and_tokenizer
from ...train.dpo.collator import DPODataCollatorWithPadding
from ...train.dpo.trainer import CustomDPOTrainer
from ...train.utils import create_modelcard_and_push, create_ref_model
if TYPE_CHECKING:
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Create reference model
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args)
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
loss_type=finetuning_args.dpo_loss,
ftx_gamma=finetuning_args.dpo_ftx,
model=model,
ref_model=ref_model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args),
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -0,0 +1,4 @@
from .workflow import run_ppo
__all__ = ["run_ppo"]

View File

@@ -0,0 +1,375 @@
import math
import os
import sys
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs,
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict(),
)
self.state = TrainerState()
self.control = TrainerControl()
self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
total_train_batch_size = (
self.args.per_device_train_batch_size
* self.args.gradient_accumulation_steps
* self.finetuning_args.ppo_buffer_size
* self.args.world_size
)
if self.args.max_steps > 0:
num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = sys.maxsize
max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
steps_in_epoch = len_dataloader
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(" Num examples = {}".format(num_examples))
logger.info(" Num Epochs = {}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
logger.info(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
total_train_batch_size
)
)
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {}".format(max_steps))
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try:
batch = next(dataiter)
except StopIteration:
dataiter = iter(self.dataloader)
batch = next(dataiter)
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
self.model.eval()
# Get inputs
self.tokenizer.padding_side = "right" # change padding side
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(
batch[idx : idx + self.config.mini_batch_size]
)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
rewards.extend(mini_batch_rewards)
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
self.model.train()
# Run PPO step
stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.config.log_with is not None:
try:
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
except Exception:
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / steps_in_epoch, 2),
)
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
loss_meter.reset()
reward_meter.reset()
if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
)
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.log_callback.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query = batch["input_ids"].detach().cpu()
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
else:
response_length = response_index[-1].item() + 1
queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
return queries, responses
@torch.no_grad()
def get_rewards(
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead",
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.
Both inputs and outputs are put on CPU.
"""
if self.finetuning_args.reward_model_type == "api":
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return get_rewards_from_server(self.reward_model, messages)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
reward_model = self.model
else:
reward_model = self.reward_model
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None,
):
r"""
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)
else:
del logits
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1] if return_logits else None,
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
Subclass and override to inject custom behavior.
"""
if self.args.should_save:
try:
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
except ValueError:
logger.warning(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" use zero_to_fp32.py to recover weights"
)
self._save(output_dir, state_dict={})
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model.save_checkpoint(output_dir)

View File

@@ -0,0 +1,59 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.packages import is_requests_available
if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
if is_requests_available():
import requests
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers)
rewards = json.loads(response.text)["scores"]
return torch.Tensor(rewards)
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.v_head.summary.weight, model.v_head.summary.bias]
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
if target == "reward": # save default head temporarily
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
layer_norm_params = {}
for name, param in model.named_parameters():
if param.data.dtype == torch.float32:
layer_norm_params[name] = param.data.detach().clone()
param.data = param.data.to(model.config.torch_dtype)
return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
for name, param in model.named_parameters():
if name in layernorm_params:
param.data = layernorm_params[name]

View File

@@ -0,0 +1,107 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from typing import TYPE_CHECKING, List, Optional
from torch.optim import AdamW
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from trl import PPOConfig
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.ppo.trainer import CustomPPOTrainer
from ...train.utils import create_ref_model, create_reward_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_ppo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, training_args.do_train, add_valuehead=True
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=finetuning_args.ppo_epochs,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False},
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
# Initialize our Trainer
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [FixValueHeadModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
# Training
if training_args.do_train:
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
ppo_trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -0,0 +1,4 @@
from .workflow import run_pt
__all__ = ["run_pt"]

View File

@@ -1,19 +1,20 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
import math import math
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForSeq2Seq
from transformers import DataCollatorForLanguageModeling, Trainer
from ...data import get_dataset, split_dataset
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.utils import create_modelcard_and_push
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_pt( def run_pt(
@@ -21,35 +22,30 @@ def run_pt(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] callbacks: Optional[List["TrainerCallback"]] = None,
): ):
dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Initialize our Trainer # Initialize our Trainer
trainer = PeftTrainer( trainer = Trainer(
finetuning_args=finetuning_args,
model=model, model=model,
args=training_args, args=training_args,
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train) **split_dataset(dataset, data_args, training_args),
) )
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train() train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()
trainer.save_model() if trainer.is_world_process_zero() and finetuning_args.plot_loss:
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation # Evaluation
@@ -61,6 +57,8 @@ def run_pt(
perplexity = float("inf") perplexity = float("inf")
metrics["perplexity"] = perplexity metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -0,0 +1,4 @@
from .workflow import run_rm
__all__ = ["run_rm"]

View File

@@ -1,8 +1,11 @@
import torch from dataclasses import dataclass
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r""" r"""
Data collator for pairwise data. Data collator for pairwise data.
@@ -16,7 +19,11 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
the last n examples represent rejected examples. the last n examples represent rejected examples.
""" """
features = [ features = [
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])} {
for key in ("accept_ids", "reject_ids") for feature in features "input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
}
for key in ("chosen_ids", "rejected_ids")
for feature in features
] ]
return super().__call__(features) return super().__call__(features)

View File

@@ -1,6 +1,7 @@
import numpy as np
from typing import Dict, Sequence, Tuple, Union from typing import Dict, Sequence, Tuple, Union
import numpy as np
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds preds, _ = eval_preds

View File

@@ -0,0 +1,99 @@
import json
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
class PairwiseTrainer(Trainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
Note that the first element will be removed from the output tuple.
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
# Split the inputs and rewards into two parts, chosen and rejected
batch_size = inputs["input_ids"].size(0) // 2
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
chosen_scores, rejected_scores = [], []
# Compute pairwise loss. Only backprop on the different tokens before padding
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
loss = 0
for i in range(batch_size):
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
if len(check_divergence) == 0:
end_index = chosen_length
div_index = end_index - 1
else:
end_index = max(chosen_length, rejected_length)
div_index = check_divergence[0]
assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
loss = loss / batch_size
if return_outputs:
chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
return loss, [loss, chosen_scores, rejected_scores]
return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
writer.write("\n".join(res))

View File

@@ -0,0 +1,79 @@
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, List, Optional
from transformers import Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.rm.collator import PairwiseDataCollatorWithPadding
from ...train.rm.metric import compute_accuracy
from ...train.rm.trainer import PairwiseTrainer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_rm(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, training_args.do_train, add_valuehead=True
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = PairwiseTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks + [FixValueHeadModelCallback()],
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args),
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -0,0 +1,4 @@
from .workflow import run_sft
__all__ = ["run_sft"]

View File

@@ -1,16 +1,24 @@
import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba import numpy as np
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available():
from rouge_chinese import Rouge
@dataclass @dataclass
class ComputeMetrics: class ComputeMetrics:
@@ -25,7 +33,7 @@ class ComputeMetrics:
Uses the model predictions to compute metrics. Uses the model predictions to compute metrics.
""" """
preds, labels = eval_preds preds, labels = eval_preds
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -49,6 +57,5 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4)) score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()} return {k: float(np.mean(v)) for k, v in score_dict.items()}

View File

@@ -0,0 +1,100 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
r"""
Removes the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
inputs["labels"] = inputs["labels"][:, :prompt_len]
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
r"""
Pads the tensor to the same length as the target tensor.
"""
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
)
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
if len(pad_len):
preds[i] = np.concatenate(
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
) # move pad token to last
decoded_labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for label, pred in zip(decoded_labels, decoded_preds):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))

View File

@@ -0,0 +1,101 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from ...data import get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Override the decoding parameters of Seq2SeqTrainer
training_args_dict = training_args.to_dict()
training_args_dict.update(
dict(
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
)
)
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args, training_args),
)
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -0,0 +1,91 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
from ..extras.callbacks import LogCallback
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model_and_tokenizer
from .dpo import run_dpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None):
model_args, _, finetuning_args, _ = get_infer_args(args)
if model_args.export_dir is None:
raise ValueError("Please specify `export_dir`.")
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
raise ValueError("Cannot merge adapters to a quantized model.")
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
if getattr(model, "quantization_method", None):
model = model.to("cpu")
elif hasattr(model.config, "torch_dtype"):
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
model = model.to(torch.float16).to("cpu")
setattr(model.config, "torch_dtype", torch.float16)
model.save_pretrained(
save_directory=model_args.export_dir,
max_shard_size="{}GB".format(model_args.export_size),
safe_serialization=(not model_args.export_legacy_format),
)
if model_args.export_hub_model_id is not None:
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
max_shard_size="{}GB".format(model_args.export_size),
safe_serialization=(not model_args.export_legacy_format),
)
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
run_exp()

120
src/llmtuner/train/utils.py Normal file
View File

@@ -0,0 +1,120 @@
from typing import TYPE_CHECKING, Optional, Union
import torch
from ..extras.logging import get_logger
from ..hparams import FinetuningArguments, ModelArguments
from ..model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
logger = get_logger(__name__)
def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory", finetuning_args.finetuning_type],
}
if not training_args.do_train:
pass
elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(
dict(
model_name_or_path=finetuning_args.ref_model,
adapter_name_or_path=finetuning_args.ref_model_adapters,
quantization_bit=finetuning_args.ref_model_quantization_bit,
)
)
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(
ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info("Use reward server {}".format(finetuning_args.reward_model))
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer(
"default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
)
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(
dict(
model_name_or_path=finetuning_args.reward_model,
adapter_name_or_path=finetuning_args.reward_model_adapters,
quantization_bit=finetuning_args.reward_model_quantization_bit,
)
)
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model

View File

@@ -1,5 +0,0 @@
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo

Some files were not shown because too many files have changed in this diff Show More