125 Commits

Author SHA1 Message Date
hiyouga
6eed1db36c Release v0.1.7
Former-commit-id: 81abe8d6cabaa1ebe74dc32a5dc143389e4c9f31
2023-08-18 17:21:27 +08:00
hiyouga
948124f55e tiny fix
Former-commit-id: 0ee159654ac6339c162745b004e2152ba6fe3c81
2023-08-18 13:07:35 +08:00
hiyouga
2b191ca776 support ppo score norm (trl 0.5.1.dev required)
Former-commit-id: 2b25db6d260ec1532281a592e873579346c7d21c
2023-08-18 12:02:42 +08:00
hiyouga
be4d2822ea fix PPO trainer #551 , update readme
Former-commit-id: faead74849470cebae9e37cde5fab2a71b32aa43
2023-08-18 11:43:10 +08:00
hiyouga
736ddd0319 update readme
Former-commit-id: beaf2fb737dbe64d35334d88b42935c89ef09eee
2023-08-18 01:51:55 +08:00
hiyouga
dfa289aa72 Update .gitignore
Former-commit-id: a1772a4dfef8dfaf7c2c321fad0a70ccf95fe6a0
2023-08-18 01:43:42 +08:00
hiyouga
c2644f939a update training resuming
Former-commit-id: 2ec75c31f609e65116ac3b621eeb7d8ccbf69135
2023-08-18 01:41:17 +08:00
hoshi-hiyouga
f11c1ae562 Merge pull request #434 from niuba/main
add last_checkpoint support

Former-commit-id: b78d461f2826c194c332ead37825704c2cb8b910
2023-08-18 01:38:31 +08:00
hoshi-hiyouga
3126164aa6 Merge branch 'main' into main
Former-commit-id: 870d2c7bf74d0da5a927bef4b8b01d15cc66a3e9
2023-08-18 01:37:23 +08:00
hiyouga
ed10486cad support bf16 ppo #551
Former-commit-id: 092088967de7409a2d51847cfc7afc83a8887320
2023-08-18 00:40:32 +08:00
hiyouga
04fa430c6c fix ChatGLM2 ppo #527 #528
Former-commit-id: 60d6ad64d7c9f6445b0df8de0153c3a311974198
2023-08-18 00:34:59 +08:00
hiyouga
fa1893b59c fix generation bug #532
Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
2023-08-17 22:21:34 +08:00
hiyouga
e993e717a5 fix streaming in pt stage #548 #549
Former-commit-id: 050e992bee2a9293cc7399b578de807b5bf9bddc
2023-08-17 17:59:26 +08:00
hiyouga
c80e56423a update readme
Former-commit-id: b74af3c9cf29e1690ae4d5acb27599b1abd152e2
2023-08-17 11:00:22 +08:00
hiyouga
ffa09a01d6 fix baichuan and intern template
Former-commit-id: e1fd18fa6ef1009f978aca5210a259251a0b19a6
2023-08-17 01:27:20 +08:00
hiyouga
7d04f8567b fix generation
Former-commit-id: 66a0300d312ef91c24fcf80667fa3b0bb8e1a342
2023-08-16 22:39:54 +08:00
hiyouga
baa709674f fix system prompt
Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
2023-08-16 01:35:52 +08:00
hiyouga
ca9a494d0c fix baichuan template #481
Former-commit-id: 7608c6c25877d97ef26a1c209c4073c9c42f4535
2023-08-15 11:38:21 +08:00
hoshi-hiyouga
37eb8c05cc Merge pull request #516 from liuyanyi/add_gitignore
[Enhance] Add .gitignore file

Former-commit-id: 12cfe5482f5ef95d8c386d0af0de381e72eab0f9
2023-08-15 11:25:40 +08:00
hiyouga
7c046edb7b fix ChatGLM RLHF
Former-commit-id: 4e43e887e432ceb7e9287b4e309b63af3c3ba1bf
2023-08-15 11:19:20 +08:00
Yanyi Liu
22cea38b20 Add .gitignore
Former-commit-id: a2ebdeef81706596617da4409fc5da71739bccdc
2023-08-15 11:13:45 +08:00
hiyouga
ef2ca0a827 alert pad_token source
Former-commit-id: f26a84e0d927d2554890daf431a93652e18f4235
2023-08-15 00:07:56 +08:00
hiyouga
7f0b908de2 update webui
Former-commit-id: da30d0fb4abdb825f3383ddd106bb06a84695b7a
2023-08-14 22:45:26 +08:00
hoshi-hiyouga
5fc5e776ff Merge pull request #511 from hiyouga/feature-autoTemplate
add template match and stage in webui

Former-commit-id: 413752ecba845cddaff5fb48db7d3d24b960eec1
2023-08-14 22:44:04 +08:00
codemayq
93b281c016 auto match template when change model_name
Former-commit-id: ab2d7ab0572765ce33a52ac71641062d5d904db4
2023-08-14 20:56:05 +08:00
codemayq
9585699918 add template match and stage in webui
Former-commit-id: d6283e7f041f08f76d18350cb5f6a6c58ca80e92
2023-08-14 20:42:59 +08:00
hiyouga
bceaba551d fix ChatGLM lm_head #494
Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
2023-08-14 14:14:48 +08:00
hiyouga
0bfeed3a7e fix bug in webui
Former-commit-id: c95f0f687689934379b6c24abf872ffcde06073b
2023-08-14 11:38:42 +08:00
hiyouga
70a780c3c0 fix webui cache
Former-commit-id: 9aba5c197fbc8abaab77f454374f8b497f0310d0
2023-08-14 11:37:01 +08:00
hiyouga
d74ab5306c update readme_zh
Former-commit-id: bdfe7e0285fdeb3a2728669dbdabf70c9652735c
2023-08-14 11:13:25 +08:00
hiyouga
688e8601ab web UI integrating RLHF
Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
2023-08-14 10:48:47 +08:00
hiyouga
4933ab5956 fix #480
Former-commit-id: ec15ca8fffacba2c34e1849c5ce90ca9989d66a2
2023-08-14 00:23:56 +08:00
hiyouga
6c7225a5d4 fix webui
Former-commit-id: 2c8b7414be9b43e20cc1d0575cc4dc1c7545fd86
2023-08-12 23:52:07 +08:00
hiyouga
a22982f2fa tiny fix
Former-commit-id: 50a34c043de6d9e1410291e1d8c1ea9d53754e9e
2023-08-12 22:02:43 +08:00
hiyouga
c95479dddb fix rope scaling
Former-commit-id: 2e0dd36700ec5e8294581c1db4b9431f755fc5f8
2023-08-12 22:00:01 +08:00
hiyouga
fc48bd8da0 update readme
Former-commit-id: 94ac570cb62aa9cd5dba105f0bb4c4da43eca042
2023-08-12 21:29:06 +08:00
hiyouga
d5323bfa3f update readme
Former-commit-id: ecfe87f34b383901f8e97ffb90af459cd55419b1
2023-08-12 21:25:19 +08:00
hiyouga
e9d4a2b507 update readme
Former-commit-id: eadbe9b7a0b6c8897e7a763b519cc5b7e00f3b2c
2023-08-12 21:23:05 +08:00
hiyouga
37bcbe8046 update readme
Former-commit-id: 6fa381400c21fa249cebcdff8c3afd72f8de20b3
2023-08-12 21:00:11 +08:00
hiyouga
fdfb644f0a support rope scaling, fix #475 #476 #478
Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
2023-08-12 20:46:27 +08:00
hoshi-hiyouga
cde9f3db57 Merge pull request #479 from hiyouga/feature-addCmdExport
add sft script preview in webui

Former-commit-id: 060225e57d13d8164beb6920410c181fbb28b77a
2023-08-12 20:41:52 +08:00
codemayq
8bf5a98815 add sft script preview in webui
Former-commit-id: 2b72649b404750226aa418b61ef5a6c9ac03938f
2023-08-12 13:53:55 +08:00
hiyouga
be566a15a5 fix unusual output of 8bit models #278 #391
Former-commit-id: 337ce5272b81f5561162beb08814b0e5abf23703
2023-08-12 00:25:29 +08:00
hiyouga
d5f1b99ac4 Release v0.1.6
Former-commit-id: 43c8b3c3c8bfb2e32d17fb3e8b194938e37d54bd
2023-08-11 23:25:57 +08:00
hiyouga
2144bb0e27 Update README_zh.md
Former-commit-id: 4fc154bcf039ba3f9240213158df757881cf3579
2023-08-11 14:06:02 +08:00
hiyouga
bc665bacc7 add defaults
Former-commit-id: 4636d3bbe6b984ca93e3a80ae5239f3ddda461bd
2023-08-11 13:56:26 +08:00
hiyouga
52bfcf4883 fix stop word in baichuan template
Former-commit-id: cba5ac9cfc81f11b97831998ea15def5e0b487c2
2023-08-11 13:51:46 +08:00
hiyouga
06df3d6fb6 fix baichuan template
Former-commit-id: b1681fe35346381cda613297f1cbb710f0a6daa6
2023-08-11 13:45:47 +08:00
hiyouga
ca719a8697 support DPO training (2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
2023-08-11 03:02:53 +08:00
hoshi-hiyouga
72dfd74005 Merge pull request #451 from jovialchen/main
huggingface login for projects must login while running

Former-commit-id: 246ac241277908909b81cdf85fec1f24449dbae9
2023-08-10 17:25:38 +08:00
hiyouga
69302c4420 fix webui val size
Former-commit-id: 490c067d4e0828832e0ebdb704a9207dc974b15b
2023-08-10 15:20:44 +08:00
jiongxuc
42d7019b2e huggingface login for projects must login while running
Former-commit-id: 0a4a2a1d3e0ff1f57215512d294d782080bd383c
2023-08-10 14:57:12 +08:00
hiyouga
5f0d0d6b9b fix template
Former-commit-id: e3967eb1cdd8d19e8afee9ba52e7eb7d6cd86129
2023-08-09 23:14:27 +08:00
hiyouga
76cb63e4f6 fix template
Former-commit-id: 907e8cd86fbd4cdfa26dad21ceaf6e01d8fe37e4
2023-08-09 23:10:20 +08:00
hiyouga
467d571206 support val set in streaming mode
Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
2023-08-09 23:00:26 +08:00
hiyouga
972bfa700a fix tokenizer
Former-commit-id: 7849587cd4e149291d08edef9a528a1bad796c7e
2023-08-09 17:52:15 +08:00
niuba
458955d0fb add last_checkpoint support
Former-commit-id: 9f1977e4de00b14a9d1b555c25bcaf12998d5046
2023-08-09 16:39:27 +08:00
hiyouga
990eeccf45 fix sft trainer
Former-commit-id: 08cc888b1569572d0cd20bcf3f07e20072a0311a
2023-08-09 16:35:03 +08:00
hiyouga
a3a7465f00 fix rm #420, fix template #426, fix #423
Former-commit-id: 70ea3caaa7a7695c77179cd1bb18707a80a373d7
2023-08-09 16:23:31 +08:00
hoshi-hiyouga
031a819257 fix llama2 template
Former-commit-id: 6c74f726d4e672f5a1a57df201c27c1f697384f0
2023-08-09 00:58:27 +08:00
hoshi-hiyouga
eb4b4e3c8c fix tokenizer
Former-commit-id: fa463ef279b596d5d53cc169831f51b42031fc05
2023-08-09 00:54:54 +08:00
hiyouga
d2e1fe9b1d update webui
Former-commit-id: 343a4cd82b07a40f96ba413d1d991419ff07a24a
2023-08-09 00:26:11 +08:00
hiyouga
6e27a9e39a fix tokenizer #417
Former-commit-id: 01aa678311bfd213a4b410a4e0ff09f48a0d40a1
2023-08-08 23:59:41 +08:00
hiyouga
805478c911 fix bug
Former-commit-id: 0dff1d951f1a9fe05a74d334bf477b55c7c64199
2023-08-08 21:28:28 +08:00
hiyouga
a281cdeb89 fix bug
Former-commit-id: c13ce66021b21e015871b84489eeafa127a424a4
2023-08-08 17:55:55 +08:00
hiyouga
cda698a67f fix chatml template #408
Former-commit-id: 21e0cc3f44c35ae689b00b274391492f413725ac
2023-08-08 17:44:39 +08:00
hiyouga
15acd17716 update args spec
Former-commit-id: a006068346edda6e2851b23d2005fdb218a7287d
2023-08-07 15:23:35 +08:00
hiyouga
34a2bddfcd update readme
Former-commit-id: 06bcbb901f69265632892a5fcbc956b8be1153da
2023-08-07 15:02:02 +08:00
hiyouga
370f817549 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Efficient-Tuning
Former-commit-id: 5c5657227db285048e3850631badb040eea9b6ca
2023-08-07 13:59:16 +08:00
hiyouga
041390c37e fix #376
Former-commit-id: a5b01257ba3323bcb2dd0217fb89a387e39ddbec
2023-08-07 13:58:59 +08:00
hoshi-hiyouga
d9fe4bf500 Merge pull request #382 from hiyouga/feature-updateReadme
add detailed model configs

Former-commit-id: 371c50cf3fd4e3f5e8fb390508c27cb5f18fa531
2023-08-07 13:43:38 +08:00
hiyouga
e0c7e944fc update trainer
Former-commit-id: 0d39b53a5164e34d22fe0a492eaa0d7ac63102fe
2023-08-07 13:34:35 +08:00
codemayq
0845fe67db add detailed model configs
Former-commit-id: 438c43f820e39738eaa1c296aadcf6d141c3289f
2023-08-07 09:30:23 +08:00
hiyouga
fe3b12d900 fix qwen eos token
Former-commit-id: 770830c67886f5872b39b9608949ec62d4616b27
2023-08-06 13:31:17 +08:00
hiyouga
a70d56864e fix qwen tokenizer #361
Former-commit-id: 78a2fa95c8ab669254a6c8fce8138c4395fb0a09
2023-08-05 17:06:05 +08:00
hiyouga
fdbb2c5378 fix template for tiktoken
Former-commit-id: 8328447f81eb5b90310df08cf2928c83ef6355fe
2023-08-05 13:42:42 +08:00
hiyouga
3c0aaf42af remove redundant code
Former-commit-id: dcec1717592107ba9d26eb2ac520309da19d1805
2023-08-05 00:27:27 +08:00
hiyouga
438e19160a fix template
Former-commit-id: b88200a88ea112e043dc44058606805c60e32844
2023-08-05 00:25:00 +08:00
hiyouga
f2b2ff6950 fix llama2 template
Former-commit-id: 08f37145e0bca5f1a8fd7bad01c64dc69b07361b
2023-08-05 00:07:54 +08:00
hoshi-hiyouga
86cef96305 Support safe ChatML template, fix qwen tok #351 #354
https://github.com/openai/openai-python/blob/main/chatml.md
Former-commit-id: 94bfc9d85f7cef3a5eb15085e0124a424373814f
2023-08-05 00:00:23 +08:00
hiyouga
5f50944baf fix bos and eos token
Former-commit-id: ab386f4c0fb5eaac24264a5bbef4c03deeb92158
2023-08-04 23:55:57 +08:00
hiyouga
0804fd2353 fix encode
Former-commit-id: ec382abd906d93cf78c7fbaec753ce6bcf8cfebd
2023-08-04 23:27:55 +08:00
hiyouga
86419eb457 support chatml safe encoding
Former-commit-id: ea52bb135bf9d07738091006ec7ada8df14cf15e
2023-08-04 23:14:28 +08:00
hiyouga
76f3ae7bf3 support interleave probs
Former-commit-id: 168d99816f9bdc746c587f7f09753ba7e0a4b19d
2023-08-04 21:27:35 +08:00
hiyouga
aaa85190eb fix webui export model
Former-commit-id: c34469c05e681239db23e2e666b5ac6a4e38aba9
2023-08-04 14:20:27 +08:00
hiyouga
e2a4e926b9 fix mtloader
Former-commit-id: ca48c2c02c3cfa9afb99971b50daeda9cf14e7cb
2023-08-03 19:29:02 +08:00
hiyouga
d6e922dc1c tiny fix
Former-commit-id: 81ef7017a4c96441951adeff0276cc5ab76a3544
2023-08-03 17:42:28 +08:00
hiyouga
27f4317ec6 fix qwen inference
Former-commit-id: 823f0de0ca0a92b6f48a90e5ffe57a48dc018f1d
2023-08-03 16:31:55 +08:00
hiyouga
e434348216 fix qwen inference
Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
2023-08-03 16:15:38 +08:00
hiyouga
2e19afedb8 support Qwen-7B, fix InternLM-7B inference
Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
2023-08-03 15:53:32 +08:00
hiyouga
da08fa7c63 update web demo
Former-commit-id: 5b6ad9adb665096bfb36dc90789a1d4a16345122
2023-08-03 13:28:28 +08:00
hiyouga
9c96b97dc7 fix webui
Former-commit-id: e87630ef77977b2879f1199b9a421acbbbb32a51
2023-08-03 12:43:12 +08:00
hiyouga
28a51b622b modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
2023-08-02 23:17:36 +08:00
hiyouga
8bd1da7144 fix PPO trainer
Former-commit-id: 21982a7d4dd9b7c3a1145b481f02b9990e32dc00
2023-08-02 19:10:23 +08:00
hiyouga
e4d0b8ee6e update ppo trainer
Former-commit-id: c27136a83e167465d3f825e40f10c7b9fcfbf97a
2023-08-02 18:46:41 +08:00
hiyouga
1dfb28b362 fix memory leak of PPO trainer
Former-commit-id: 38410894a5ebf0b043b55a6bd5cca3cd0a44b27d
2023-08-02 17:41:34 +08:00
hiyouga
ba618947e7 release v0.1.5
Former-commit-id: d619e76bc4098c29a7fdc05f5a71208bd1079c9f
2023-08-02 16:10:31 +08:00
hoshi-hiyouga
f81041b502 Merge pull request #307 from GitYCC/feature/fix-llama2-prompt-template
[feature] Fix template of Llama2 to match the offical template

Former-commit-id: a750b1f1ed16e20233df4d2f1c20507122919f5a
2023-08-02 15:51:28 +08:00
YC Chen
f2533a2800 [fix] Remove useless code
Former-commit-id: 077e1556112913e4eeef47e581055183b39d5404
2023-08-02 14:35:35 +08:00
YC Chen
bb5b4a7f26 [feature] Fix template of Llama2 to match the offical template
Former-commit-id: 1a98d45aefd95eea3768fb93e5a9da257ec61181
2023-08-02 14:10:15 +08:00
hiyouga
20bff87021 fix bug in preprocessing
Former-commit-id: 94952894576dfc4b42118162aec9aa35c3503c40
2023-08-02 01:10:28 +08:00
hiyouga
722b954800 update readme
Former-commit-id: 5154a04869be8c47e591351565b7842339fb99e4
2023-08-01 18:48:27 +08:00
hiyouga
19256086c7 fix #296
Former-commit-id: 69e9ed9b96a7cfb3d3b43ec5ddd01aa0bfd9b784
2023-08-01 18:43:53 +08:00
hiyouga
250fecfcd4 Fix #294
Former-commit-id: 09762d9849655f5e6c71b9472d55b42489dd944b
2023-08-01 18:13:03 +08:00
hiyouga
cb4d1d5ebb restore from git lfs
Former-commit-id: 0c734a37113b773ae7c0bc8b8d1af39b15bc0fb2
2023-08-01 16:33:25 +08:00
hiyouga
d7d557fb2e Update .gitattributes
Former-commit-id: 92e68f9f30c2fc91ae1b40865bc5c2d94899ba22
2023-08-01 16:28:54 +08:00
hiyouga
0b8e19b6a6 fix webui
Former-commit-id: cf4cd52d36894f53a6ec45d003f887771012e5b4
2023-08-01 12:11:37 +08:00
hiyouga
8e26eb374e fix RM save model
Former-commit-id: 8104cc2425431eb1cddccf3909855296116f922b
2023-08-01 11:56:17 +08:00
hiyouga
9bba01a033 use git lfs
Former-commit-id: 4886d0071751f68c5a2d926bd9fcee0c93337322
2023-08-01 10:14:08 +08:00
hiyouga
661890b8a1 release v0.1.4
Former-commit-id: 81f84aaf2e120e39edb28ef42893939fc9a184e2
2023-08-01 10:08:47 +08:00
hiyouga
772ad4ec6b fix inference
Former-commit-id: 55dc2bdd3eaa552c655e584fc3cbbf017c7bc3e7
2023-08-01 00:06:48 +08:00
hiyouga
6f65f8cb3b fix arg check
Former-commit-id: 2c5c73de9ebc88e2d04e80754781c94a571133a0
2023-07-31 23:48:57 +08:00
hiyouga
43e83548b9 update readme
Former-commit-id: d99cda254e5025ff3f968d256197ab031bfabef1
2023-07-31 23:42:32 +08:00
hiyouga
dd3f3e9749 support streaming data, fix #284 #274 #268
Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
2023-07-31 23:33:00 +08:00
hiyouga
124f61b404 Update data_args.py
Former-commit-id: 41ac5455af195747ba369c3a6dc7d412a366d54d
2023-07-28 17:42:41 +08:00
hiyouga
e8748cc6f3 update readme
Former-commit-id: 14d20cd1fdcfd1f2842362f70472b666e5d48c7d
2023-07-28 17:36:00 +08:00
hiyouga
fafec8b7a5 fix #268
Former-commit-id: 1eee0207fb370bb9e234e9bd3f9a0c47d7d01bc9
2023-07-28 17:02:26 +08:00
hiyouga
030daca686 update dataset
Former-commit-id: 4a044aabbd19c92a9ae93c1c30536f5086fd47f9
2023-07-26 17:05:12 +08:00
hiyouga
ac587438f8 fix #242
Former-commit-id: 80a346e29beb49e8935b786e2af1059fdc4954b2
2023-07-25 17:04:02 +08:00
hiyouga
c145bbef3c update dataset
Former-commit-id: 4fc2c3293d91d8464527ebd1ddabe572c8355616
2023-07-23 20:01:43 +08:00
hiyouga
745c46ee04 Update README_zh.md
Former-commit-id: 9d3c8803a34c06a2a5512fec3f841d7efcab3e3c
2023-07-22 14:31:16 +08:00
hiyouga
a707f5b502 update readme, fix web ui postprocess
Former-commit-id: ba51ab3379100108f7b52a3c2444ccdd99e8a6ef
2023-07-22 14:29:22 +08:00
hoshi-hiyouga
dc2e801077 Merge pull request #221 from mrhan1993/main
根据GLM Efficient Tuning添加中文README,web添加了server_port参数

Former-commit-id: 948f2abcb818211ee99d4c140e26044ca591369f
2023-07-22 13:04:25 +08:00
NULL
b56d5108b2 Merge branch 'hiyouga:main' into main
Former-commit-id: 8244c5b554ad0823e8ebea3d5583a6ecf9a66d2d
2023-07-21 17:00:26 +08:00
mrhan1993
8e6b7034fe 根据GLM Efficient Tuning添加中文README,web添加了server_port
Former-commit-id: 29e3acd23eafd891667d7a860ec544a5b05d3c33
2023-07-21 16:57:58 +08:00
70 changed files with 3288 additions and 1278 deletions

160
.gitignore vendored Normal file
View File

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

295
README.md
View File

@@ -8,17 +8,31 @@
👋 Join our [WeChat](assets/wechat.jpg).
\[ English | [中文](README_zh.md) \]
## Changelog
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.
[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
@@ -28,39 +42,48 @@
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
## Supported Models
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
- [InternLM](https://github.com/InternLM/InternLM) (7B)
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- |----------|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
## Supported Training Approaches
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- Full-parameter tuning
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
- Full-parameter tuning
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
- Use `--quantization_bit 4/8` argument to enable QLoRA.
## Provided Datasets
- For pre-training:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised fine-tuning:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
@@ -68,7 +91,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -77,12 +99,13 @@
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling:
- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -100,18 +123,13 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- sentencepiece and tiktoken
- jieba, rouge-chinese and nltk (used at evaluation)
- gradio and matplotlib (used in web_demo.py)
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
And **powerful GPUs**!
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
## Getting Started
### Data Preparation (optional)
@@ -130,21 +148,35 @@ cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### All-in-one Web UI
```bash
python src/train_web.py
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
### (Continually) Pre-Training
We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
Currently the web UI only supports training on **a single GPU**.
### Train on a single GPU
#### Pre-Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
@@ -158,15 +190,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### Supervised Fine-Tuning
#### Supervised Fine-Tuning
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
@@ -180,36 +214,43 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### Reward Model Training
#### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--learning_rate 1e-6 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### PPO Training (RLHF)
#### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
@@ -220,18 +261,45 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--plot_loss
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Distributed Training
#### Use Huggingface Accelerate
```bash
accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>
```yaml
compute_environment: LOCAL_MACHINE
@@ -259,14 +327,96 @@ use_cpu: false
</details>
#### Use DeepSpeed
```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--deepspeed ds_config.json \
... # arguments (same as above)
```
<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
```json
{
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### Export model
```bash
python src/export_model.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
### API Demo
```bash
python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
Visit `http://localhost:8000/docs` for API documentation.
### CLI Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Evaluation (BLEU and ROUGE_CHINESE)
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_eval \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
@@ -282,9 +432,10 @@ We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_predict \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
@@ -293,46 +444,11 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
If you want to predict the samples with empty responses, please kindly fill the `response` column with **dummy tokens** to ensure the sample will not be discarded throughout the preprocessing phase.
## TODO
### API Demo
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
Visit `http://localhost:8000/docs` for API documentation.
### CLI Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Export model
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
- [ ] Implementing multi-query attention for faster inference.
- [ ] Supporting full-parameter RLHF training.
## License
@@ -344,8 +460,11 @@ Please follow the model licenses to use the corresponding model weights:
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## Citation

487
README_zh.md Normal file
View File

@@ -0,0 +1,487 @@
# LLaMA Efficient Tuning
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Efficient-Tuning?style=social)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Efficient-Tuning)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
👋 加入我们的[微信群](assets/wechat.jpg)。
\[ [English](README.md) | 中文 \]
## 更新日志
[23/08/18] 现在我们支持了**训练状态恢复**,请将 `transformers` 升级至 `4.31.0` 以启用此功能。
[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat``--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型时请添加 `--template chatml` 参数。
[23/07/31] 现在我们支持了**数据流式加载**。请尝试使用 `--streaming``--max_steps 10000` 参数来流式加载数据集。
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base``--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。
[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。使用 InternLM-chat 模型时请添加 `--template intern` 参数。
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b``--lora_target query_key_value` 参数。
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
[23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B``--lora_target W_pack` 参数。
[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
[23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt``--lora_target query_key_value` 参数。
## 模型
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- |----------|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
- 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用对应的模板。
## 训练方法
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
- 使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
## 数据集
- 用于预训练:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- 用于指令监督微调:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- 用于奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
使用方法请参考 [data/README.md](data/README_zh.md) 文件。
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
```bash
pip install --upgrade huggingface_hub
huggingface-cli login
```
## 软件依赖
- Python 3.8+ 和 PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
- sentencepiece 和 tiktoken
- jieba, rouge-chinese 和 nltk (用于评估)
- gradio 和 matplotlib (用于网页端交互)
- uvicorn, fastapi 和 sse-starlette (用于 API)
以及 **强而有力的 GPU**
## 如何使用
### 数据准备(可跳过)
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`
### 环境搭建(可跳过)
```bash
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
如果要在 Windows 平台上开启量化 LoRAQLoRA需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### 浏览器一体化界面
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
目前网页 UI 仅支持**单卡训练**。
### 单 GPU 训练
#### 预训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### 指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### 奖励模型训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-6 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
```
#### DPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### 多 GPU 分布式训练
#### 使用 Huggingface Accelerate
```bash
accelerate config # 首先配置分布式环境
accelerate launch src/train_bash.py # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 Accelerate 配置示例</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
#### 使用 DeepSpeed
```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--deepspeed ds_config.json \
... # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 DeepSpeed 配置示例</summary>
```json
{
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### 导出微调后的模型
```bash
python src/export_model.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
### API 服务
```bash
python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
关于 API 文档请见 `http://localhost:8000/docs`
### 命令行测试
```bash
python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### 浏览器测试
```bash
python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### 指标评估BLEU 分数和汉语 ROUGE 分数)
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_eval \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128`
### 模型预测
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
## TODO
- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。
- [ ] 在推理阶段使用 Multi-query attention 进行加速。
- [ ] 支持 RLHF 的全参数微调。
## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## 引用
如果您觉得此项目有帮助,请考虑以下列格式引用
```bibtex
@Misc{llama-efficient-tuning,
title = {LLaMA Efficient Tuning},
author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
year = {2023}
}
```
## 致谢
本项目是 [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) 的同类项目。采用了类似的代码结构和训练方法。
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Efficient-Tuning&type=Date)

18
data/README_zh.md Normal file
View File

@@ -0,0 +1,18 @@
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以如下格式提供您的数据集定义。
```json
"数据集名称": {
"hf_hub_url": "HuggingFace上的项目地址若指定则忽略下列三个参数",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选",
"columns": {
"prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None"
}
}
```
其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。

View File

@@ -1 +1 @@
0a57fbc1d8cb08a8cd71c5eb8425cf59206ffed6
57fd080be5bffe4153fe3ee26a175e3d56da30f3

View File

@@ -1 +0,0 @@
56405bb8f52727e52e99693739494b9b7b0d7ba6

View File

@@ -1 +0,0 @@
fa935248a5d40d2bdd5649af99a72a754d40ae7a

View File

@@ -3,8 +3,10 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
trl>=0.4.7
trl>=0.5.0
scipy
sentencepiece
tiktoken
jieba
rouge-chinese
nltk

View File

@@ -1,19 +1,13 @@
# coding=utf-8
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
import uvicorn
from llmtuner import ChatModel
from llmtuner.api.app import create_app
from llmtuner.tuner import get_infer_args
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
print("Visit http://localhost:8000/docs for API document.")
if __name__ == "__main__":

View File

@@ -1,13 +1,8 @@
# coding=utf-8
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args
def main():
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
history = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

View File

@@ -1,16 +1,8 @@
# coding=utf-8
# Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from llmtuner.tuner import get_train_args, load_model_and_tokenizer
from llmtuner import export_model
def main():
model_args, _, training_args, finetuning_args, _ = get_train_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
tokenizer.save_pretrained(training_args.output_dir)
print("model and tokenizer have been saved at:", training_args.output_dir)
export_model()
if __name__ == "__main__":

View File

@@ -1,4 +1,9 @@
# Level: api, webui > chat > tuner > dsets > extras, hparams
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.1.3"
__version__ = "0.1.7"

View File

@@ -0,0 +1 @@
from llmtuner.api.app import create_app

View File

@@ -5,9 +5,8 @@ from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
from llmtuner.tuner import get_infer_args
from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
@@ -48,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != Role.USER:
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
system = prev_messages.pop(0).content
else:
prefix = None
system = None
history = []
if len(prev_messages) % 2 == 0:
@@ -65,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
generate = predict(query, history, prefix, request)
generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
)
usage = ChatCompletionResponseUsage(
@@ -86,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT),
@@ -96,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
yield chunk.json(exclude_unset=True, ensure_ascii=False)
for new_text in chat_model.stream_chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
):
if len(new_text) == 0:
continue
@@ -122,6 +121,6 @@ def create_app(chat_model: ChatModel) -> FastAPI:
if __name__ == "__main__":
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

View File

@@ -3,35 +3,37 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import get_template
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner import load_model_and_tokenizer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
class ChatModel:
def __init__(
self,
model_args: ModelArguments,
data_args: DataArguments,
finetuning_args: FinetuningArguments,
generating_args: GeneratingArguments
) -> None:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.template = get_template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix or ""
self.generating_args = generating_args
self.model = dispatch_model(self.model)
self.model = self.model.eval() # enable evaluation mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.system_prompt = data_args.system_prompt
def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix or self.source_prefix
system = system or self.system_prompt
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
@@ -41,11 +43,14 @@ class ChatModel:
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
input_ids=input_ids,
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
pad_token_id=self.tokenizer.pad_token_id,
logits_processor=get_logits_processor()
))
@@ -61,9 +66,13 @@ class ChatModel:
@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@@ -72,9 +81,13 @@ class ChatModel:
@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

View File

@@ -1,40 +1,27 @@
import os
import hashlib
from typing import List
from typing import TYPE_CHECKING, List, Union
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, DataArguments
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
model_args: ModelArguments,
data_args: DataArguments
) -> Dataset:
def checksum(file_path, hash):
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
ext2type = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
@@ -47,60 +34,59 @@ def get_dataset(
data_path = None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = ext2type.get(data_files[0].split(".")[-1], None)
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = ext2type.get(data_files[0].split(".")[-1], None)
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
checksum(data_files[0], dataset_attr.dataset_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
raw_datasets = load_dataset(
dataset = load_dataset(
data_path,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
)
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
prefix_data = [dataset_attr.source_prefix] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
]: # every dataset will have 4 columns same as each other
if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
dataset = dataset.add_column("prefix", prefix_data)
for column_name in ["prompt", "query", "response", "history"]: # align datasets
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.system_prompt: # add system prompt
if data_args.streaming:
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
else:
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
else:
all_datasets = concatenate_datasets(all_datasets)
return all_datasets
raise ValueError("Unknown mixing strategy.")

View File

@@ -1,91 +1,88 @@
from typing import Literal
import tiktoken
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
from itertools import chain
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from datasets import Dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template
from llmtuner.hparams import DataArguments
from llmtuner.extras.template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
def preprocess_dataset(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:
) -> Union["Dataset", "IterableDataset"]:
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
column_names = list(dataset.column_names)
prompt_template = get_template(data_args.prompt_template)
# support question with a single answer or multiple answers
def get_dialog(examples):
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
yield dialog
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def preprocess_pretrain_dataset(examples):
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
concatenated_ids = list(chain(*text_ids))
total_length = len(concatenated_ids)
block_size = data_args.max_source_length - 1
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.max_source_length
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of max_source_length
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
for i in range(0, total_length, block_size)]
return {
"input_ids": result,
"labels": result.copy()
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
def preprocess_supervised_dataset(examples):
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
model_inputs = {"input_ids": [], "labels": []}
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
for dialog in get_dialog(examples):
for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
for i in range(len(dialog) // 2):
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
if len(target_ids) > data_args.max_target_length:
target_ids = target_ids[:data_args.max_target_length]
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
break
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
input_ids += source_ids + target_ids
labels += [IGNORE_INDEX] * len(source_ids) + target_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_unsupervised_dataset(examples):
# build inputs with format `<bos> X` and labels with format `<bos> Y`
model_inputs = {"input_ids": [], "labels": []}
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
for query, response, history, system in construct_example(examples):
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@@ -93,82 +90,82 @@ def preprocess_dataset(
target_ids = target_ids[:data_args.max_target_length]
model_inputs["input_ids"].append(source_ids)
model_inputs["attention_mask"].append([1] * len(source_ids))
model_inputs["labels"].append(target_ids)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
if len(prompt_ids) > data_args.max_source_length:
prompt_ids = prompt_ids[:data_args.max_source_length]
if len(chosen_ids) > data_args.max_target_length:
chosen_ids = chosen_ids[:data_args.max_target_length]
if len(rejected_ids) > data_args.max_target_length:
rejected_ids = rejected_ids[:data_args.max_target_length]
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
skip_special_tokens=False)
))
print("labels:\n{}".format(tokenizer.decode([
token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
], skip_special_tokens=False)))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
elif stage == "sft":
if not training_args.predict_with_generate:
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
else:
preprocess_function = preprocess_unsupervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
dataset = dataset.map(
preprocess_function,
batched=True,
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
if stage == "pt":
print_unsupervised_dataset_example(dataset[0])
elif stage == "sft":
print_supervised_dataset_example(dataset[0])
elif stage == "rm":
print_pairwise_dataset_example(dataset[0])
elif stage == "ppo":
print_unsupervised_dataset_example(dataset[0])
dataset = dataset.map(
preprocess_function,
batched=True,
remove_columns=column_names,
**kwargs
)
print_function(next(iter(dataset)))
return dataset

View File

@@ -1,16 +1,59 @@
from typing import Dict
from datasets import Dataset
import hashlib
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def split_dataset(
dataset: Dataset, dev_ratio: float, do_train: bool
) -> Dict[str, Dataset]:
# Split the dataset
if do_train:
if dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=dev_ratio)
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -1,74 +1,128 @@
import os
import json
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from transformers import TrainerCallback
from transformers.trainer_utils import has_length
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
logger = get_logger(__name__)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
self.start_time = time.time()
self.tracker = {}
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def timing(self):
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
Event called at the end of training.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
r"""
Event called after a successful prediction.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if not state.is_world_process_zero:
if not state.is_local_process_zero:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"eval_loss": state.log_history[-1].get("eval_loss", None),
"predict_loss": state.log_history[-1].get("predict_loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()

View File

@@ -1,13 +1,23 @@
IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"]
STAGES = [
"SFT",
"Reward Modeling",
"PPO",
"DPO",
"Pre-Training"
]
SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
@@ -19,29 +29,50 @@ SUPPORTED_MODELS = {
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
"InternLM-7B": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"XVERSE-13B": "xverse/XVERSE-13B",
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
}
DEFAULT_MODULE = {
"LLaMA": "q_proj,v_proj",
"LLaMA2": "q_proj,v_proj",
"ChineseLLaMA2": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"InternLM": "q_proj,v_proj"
"InternLM": "q_proj,v_proj",
"Qwen": "c_attn",
"XVERSE": "q_proj,v_proj",
"ChatGLM2": "query_key_value"
}
DEFAULT_TEMPLATE = {
"LLaMA2": "llama2",
"ChineseLLaMA2": "llama2_zh",
"Baichuan": "baichuan",
"InternLM": "intern",
"Qwen": "chatml",
"ChatGLM2": "chatglm2"
}

View File

@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
super().__init__()
self.log = ""
def reset(self):
self.log = ""
def emit(self, record):
if record.name == "httpx":
return
@@ -16,8 +19,16 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
def get_logger(name: str) -> logging.Logger:
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"

View File

@@ -1,12 +1,12 @@
import torch
from typing import List, Optional
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
class AverageMeter:
r"""
@@ -28,46 +28,43 @@ class AverageMeter:
self.avg = self.sum / self.count
# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def print_trainable_params(model: torch.nn.Module) -> None:
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_params = num_params * 2
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
return trainable_params, all_param
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: PreTrainedModel,
model: "PreTrainedModel",
finetuning_type: str,
output_embedding_layer_name: Optional[str] = "lm_head",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> PreTrainedModel:
) -> "PreTrainedModel":
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
@@ -83,19 +80,20 @@ def prepare_model_for_training(
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype
if finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
return model
def torch_gc() -> None:
r"""
Collects GPU memory.
@@ -103,3 +101,28 @@ def torch_gc() -> None:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()

View File

@@ -1,6 +1,6 @@
import os
import torch
from typing import Dict, Optional
from typing import Dict
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
@@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
state_dict = model.state_dict()
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if (not trainable_only) or v.requires_grad:
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict

View File

@@ -1,64 +1,230 @@
from typing import Dict, List, Optional, Tuple
import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
logger = get_logger(__name__)
@dataclass
class Template:
prefix: str
prompt: str
sep: str
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
system: str
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> str:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a string containing prompt without response.
Returns a single pair of token ids representing prompt and response respectively.
"""
return "".join(self._format_example(query, history, prefix))
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
return prompt_ids, answer_ids
def get_dialog(
self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._format_example(query, history, prefix) + [resp]
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
return encoded_pairs
def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[str, List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, "<dummy>")]
convs = []
for turn_idx, (user_query, bot_resp) in enumerate(history):
if turn_idx == 0:
convs.append(prefix + self.prompt.format(query=user_query))
convs.append(bot_resp)
history = history + [(query, resp)]
return system, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if (
tokenizer.bos_token_id is not None
and getattr(tokenizer, "add_bos_token", True)
): # baichuan-13b has no bos token
bos_ids = [tokenizer.bos_token_id]
else:
convs.append(self.sep + self.prompt.format(query=user_query))
convs.append(bot_resp)
return convs[:-1] # drop last
bos_ids = [] # bos token is optional
if tokenizer.eos_token_id is not None:
eos_ids = [tokenizer.eos_token_id]
else:
raise ValueError("EOS token is required.")
return bos_ids, eos_ids
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
if len(prefix_ids) != 0: # has prefix
prefix_ids = bos_ids + prefix_ids + sep_ids
else:
prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
def _convert_inputs_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
system: Optional[str] = None,
query: Optional[str] = None,
idx: Optional[str] = None
) -> List[int]:
r"""
Converts context to token ids.
"""
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
token_ids = []
for elem in context:
if isinstance(elem, str):
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise NotImplementedError
return token_ids
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has no sep_ids
query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
templates: Dict[str, Template] = {}
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Template(
def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
system: str,
sep: List[Union[str, Dict[str, str]]],
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True
) -> None:
template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
system=system,
sep=sep,
stop_words=stop_words,
use_history=use_history
)
def get_template(name: str) -> Template:
def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
additional_special_tokens = template.stop_words
if len(template.stop_words): # inplace method
if tokenizer.eos_token_id is not None:
additional_special_tokens.append(tokenizer.eos_token)
tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
additional_special_tokens.pop(0)
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
return template
@@ -67,9 +233,12 @@ Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
prefix=[],
prompt=[
"{{query}}"
],
system="",
sep=[],
use_history=False
)
@@ -79,11 +248,19 @@ Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[
"\n"
]
)
@@ -94,17 +271,40 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
"""
register_template(
name="llama2",
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt=" [INST] {query} [/INST] ",
sep="</s>",
use_history=True
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
r"""
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
"""
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
)
@@ -114,11 +314,19 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
sep=[
"\n\n"
]
)
@@ -128,11 +336,17 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT: "
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
@@ -141,10 +355,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
"\n\n"
]
)
@@ -153,10 +373,16 @@ Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nBot: "
],
system="",
sep=[
"\n"
]
)
@@ -165,10 +391,16 @@ Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system="",
sep=[
"\n"
]
)
@@ -177,10 +409,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
"\n"
]
)
@@ -189,11 +430,19 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant: "
],
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
sep=[
"###"
]
)
@@ -202,10 +451,22 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
prefix=[
"{{system}}"
],
prompt=[
"<|User|>:{{query}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
],
system="",
sep=[
"\n"
],
stop_words=[
"</s>", # internlm cannot replace eos token
"<eoa>"
]
)
@@ -214,10 +475,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="</s>",
use_history=True
prefix=[
"{{system}}",
{"token": "<reserved_102>"} # user token
],
prompt=[
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
stop_words=[
"<reserved_102>" # user token
]
)
@@ -227,8 +497,71 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
{"token": "<|end|>"}
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
],
system="",
sep=[
"\n"
],
stop_words=[
"<|end|>"
]
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
"""
register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}",
{"token": "<|im_end|>"}
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
system="You are a helpful assistant.",
sep=[
"\n"
],
stop_words=[
"<|im_end|>"
]
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"\n\n"
]
)

View File

@@ -1,6 +1,6 @@
import os
import json
from typing import List, Optional
from typing import List, Literal, Optional
from dataclasses import dataclass, field
@@ -10,25 +10,28 @@ class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
source_prefix: Optional[str] = None
system_prompt: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass
class DataArguments:
"""
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: str = field(
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default="alpaca_zh",
default="alpaca_en",
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
@@ -39,6 +42,22 @@ class DataArguments:
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing."}
)
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -67,17 +86,13 @@ class DataArguments:
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
source_prefix: Optional[str] = field(
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
)
dev_ratio: Optional[float] = field(
val_size: Optional[float] = field(
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
)
prompt_template: Optional[str] = field(
default="default",
metadata={"help": "Which template to use for constructing prompts in training and inference."}
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
def init_for_training(self): # support mixing multiple datasets
@@ -85,12 +100,12 @@ class DataArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
if self.source_prefix is not None:
prefix_list = self.source_prefix.split("|")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
else:
prefix_list = [None] * len(dataset_names)
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = []
for i, name in enumerate(dataset_names):
@@ -108,12 +123,11 @@ class DataArguments:
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)

View File

@@ -5,32 +5,36 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
"""
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_hidden_layers: Optional[int] = field(
default=32,
metadata={"help": "Number of decoder blocks in the model. \
metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
Falcon choices: [\"32\", \"60\"], \
Baichuan choices: [\"32\", \"40\"]"}
Baichuan choices: [\"32\", \"40\"] \
Qwen choices: [\"32\"], \
XVERSE choices: [\"40\"]"}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
Baichuan choices: [\"mlp\", \"self_attn\"], \
Qwen choices: [\"mlp\", \"attn\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
lora_rank: Optional[int] = field(
default=8,
@@ -45,11 +49,25 @@ class FinetuningArguments:
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_target: Optional[str] = field(
default="q_proj,v_proj",
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO Training."}
)
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
def __post_init__(self):
@@ -63,17 +81,17 @@ class FinetuningArguments:
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
"""Saves the content of this instance in JSON format inside `json_path`."""
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Creates an instance from the content of `json_path`."""
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))

View File

@@ -4,10 +4,10 @@ from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
r"""
Arguments pertaining to which stage we are going to perform.
"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)

View File

@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
@dataclass
class GeneratingArguments:
"""
r"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
@dataclass
class ModelArguments:
"""
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: str = field(
@@ -43,9 +43,9 @@ class ModelArguments:
default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."}
)
compute_dtype: Optional[torch.dtype] = field(
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
metadata={"help": "Adopt scaled rotary positional embeddings."}
)
checkpoint_dir: Optional[str] = field(
default=None,
@@ -55,18 +55,33 @@ class ModelArguments:
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
hf_auth_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)
model_max_length: Optional[int] = field(
default=None,
metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
)
def __post_init__(self):
if self.compute_dtype is not None or self.model_max_length is not None:
raise ValueError("These arguments cannot be specified.")
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
if self.use_auth_token == True and self.hf_auth_token is not None:
from huggingface_hub.hf_api import HfFolder # lazy load
HfFolder.save_token(self.hf_auth_token)

View File

@@ -1,5 +1 @@
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.tune import export_model, run_exp

View File

@@ -1,7 +1,7 @@
import os
import torch
from typing import TYPE_CHECKING
from transformers.modeling_utils import PreTrainedModel
from peft import (
PeftModel,
TaskType,
@@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
model: PreTrainedModel,
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
) -> PreTrainedModel:
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -36,7 +39,7 @@ def init_adapter(
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full":
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
@@ -62,7 +65,7 @@ def init_adapter(
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir

View File

@@ -1,26 +1,33 @@
import os
import math
import torch
from typing import Literal, Optional, Tuple
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
@@ -29,15 +36,15 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
def load_model_and_tokenizer(
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
@@ -47,9 +54,6 @@ def load_model_and_tokenizer(
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
"RM and PPO training can only be performed with the LoRA method."
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
@@ -63,21 +67,67 @@ def load_model_and_tokenizer(
padding_side=model_args.padding_side,
**config_kwargs
)
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
tokenizer.pad_token_id = 0 # set as the <unk> token
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
is_mergeable = True
if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
if model_args.compute_dtype == torch.bfloat16:
setattr(config, "bf16", True)
else:
setattr(config, "fp16", True)
# Set RoPE scaling
if model_args.rope_scaling is not None:
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
if is_trainable:
logger.warning("Qwen model does not support RoPE scaling in training.")
else:
setattr(config, "use_dynamic_ntk", True)
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA models
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
model_args.rope_scaling, scaling_factor
))
else:
logger.warning("Current model does not support RoPE scaling.")
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
@@ -90,26 +140,26 @@ def load_model_and_tokenizer(
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if not is_trainable: # `device_map=auto` should be used for inference only
config_kwargs["device_map"] = "auto"
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
# Load and prepare pretrained models (without valuehead).
# Load and prepare pre-trained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
# Disable custom generate method (for Qwen)
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
setattr(model, "lm_head", model.transformer.output_layer)
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
@@ -122,9 +172,10 @@ def load_model_and_tokenizer(
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
@@ -134,16 +185,19 @@ def load_model_and_tokenizer(
})
if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
# Prepare model for inference
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
model = model.to(infer_dtype) if model_args.quantization_bit is None else model
print_trainable_params(model)
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
return model, tokenizer

View File

@@ -5,6 +5,7 @@ import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
@@ -19,20 +20,66 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
))
return _parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
if args is not None:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@@ -48,87 +95,129 @@ def get_train_args(
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if general_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
raise ValueError("Streaming mode should have an integer val size.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
# postprocess data_args
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
# postprocess training_args
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if training_args.optim == "adamw_hf":
training_args.optim = "adamw_torch" # suppress warning
if model_args.quantization_bit is not None:
if training_args.fp16:
model_args.compute_dtype = torch.float16
elif training_args.bf16:
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info(
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
)
# postprocess model_args
if training_args.bf16:
if not torch.cuda.is_bf16_supported():
raise ValueError("Current device does not support bf16 training.")
model_args.compute_dtype = torch.bfloat16
else:
model_args.compute_dtype = torch.float32
model_args.compute_dtype = torch.float16
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, general_args
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
if args is not None:
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint.")
return model_args, data_args, finetuning_args, generating_args

View File

@@ -1,35 +1,37 @@
import os
import torch
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from trl import PreTrainedModelWrapper
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer):
class PeftModelMixin:
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
"""
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self._remove_log()
def _remove_log(self):
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
def __init__(self) -> None: # for type checking
self.model: PreTrainedModel = None
self.tokenizer: "PreTrainedTokenizer" = None
self.args: "Seq2SeqTrainingArguments" = None
self.finetuning_args: "FinetuningArguments" = None
self.state: "TrainerState" = None
raise AssertionError("Mixin should not be initialized.")
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
@@ -42,31 +44,36 @@ class PeftTrainer(Seq2SeqTrainer):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()
v_head_state_dict = {
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
for name in model_state_dict.keys() if name.startswith("v_head.")
}
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
backbone_model = getattr(model, "pretrained_model")
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
model = model.pretrained_model
state_dict = state_dict or get_state_dict(model)
if isinstance(model, (PeftModel, PreTrainedModel)):
model.config.use_cache = True
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
model.config.use_cache = False
else:
backbone_model = model
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if isinstance(backbone_model, PeftModel): # LoRA tuning
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
safe_serialization=self.args.save_safetensors
)
backbone_model.config.use_cache = False
if self.tokenizer is not None:
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
try:
self.tokenizer.save_pretrained(output_dir)
else:
logger.warning("No model to save.")
except:
logger.warning("Cannot save tokenizer, copy the files manually.")
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
@@ -76,16 +83,25 @@ class PeftTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if isinstance(backbone_model, PeftModel):
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if isinstance(model, PreTrainedModelWrapper):
model.v_head.load_state_dict(torch.load(
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
))
model = model.pretrained_model
if isinstance(model, PeftModel):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(backbone_model, self.state.best_model_checkpoint)
load_trainable_params(model, self.state.best_model_checkpoint)
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
Seq2SeqTrainer.__init__(self, **kwargs)
self.finetuning_args = finetuning_args

View File

@@ -0,0 +1 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@@ -0,0 +1,51 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
start, end = feature.size(0) - answer_len, feature.size(0)
else:
start, end = prompt_len, answer_len
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
label_positions = []
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append({
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len)
})
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(
concatenated_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch

View File

@@ -0,0 +1,77 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
**kwargs
):
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = finetuning_args.dpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,
model: Optional[torch.nn.Module] = None,
batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_disable()
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
with unwrapped_model.disable_adapter():
all_logits = self.model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
else:
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_enable()
all_logps = self._get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits

View File

@@ -0,0 +1,59 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
training_args.remove_unused_columns = False # important for pairwise dataset
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
generating_args=generating_args,
ref_model=ref_model,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

View File

@@ -2,20 +2,23 @@ import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
from transformers.modeling_utils import PreTrainedModel
from transformers import TrainerState, TrainerControl
from trl import PPOTrainer
from trl.core import LengthSampler
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, get_logits_processor
from llmtuner.hparams import FinetuningArguments
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
from llmtuner.tuner.ppo.utils import replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -25,21 +28,24 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: List[LogCallback],
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["LogCallback"],
compute_dtype: torch.dtype,
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.log_callback = callbacks[0]
self.compute_dtype = compute_dtype
self.state = TrainerState()
self.control = TrainerControl()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
self._remove_log()
def ppo_train(self, max_target_length: int) -> None:
r"""
@@ -66,19 +72,16 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate`
gen_kwargs = {
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
gen_kwargs = self.generating_args.to_dict()
gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
@@ -86,51 +89,38 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
batch = next(dataiter)
steps_trained += 1
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get responses
query_tensors = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
rewards = self.get_rewards(queries, responses, unwrapped_model)
queries, responses = [], []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(
**self.prepare_model_inputs(queries, responses),
output_hidden_states=True,
return_dict=True
)
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default")
# Run PPO step
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
# Run PPO step
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
)
print(logs)
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
@@ -147,38 +137,124 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
@torch.no_grad()
def generate(
def get_inputs(
self,
inputs: Dict[str, torch.Tensor],
batch: Dict[str, torch.Tensor],
length_sampler: Optional[Callable] = None,
return_prompt: Optional[bool] = True,
**generation_kwargs
) -> torch.Tensor:
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
Subclass and override to inject custom behavior.
"""
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
unwrapped_model = self.accelerator.unwrap_model(self.model)
response = unwrapped_model.generate(**inputs, **generation_kwargs)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
queries, responses = [], []
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query[i, query_length:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
return queries, responses
@torch.no_grad()
def get_rewards(
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead"
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.
"""
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_cuda_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False
):
r"""
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
if values.size(0) != input_ids.size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
masks[j, :start] = 0
masks[j, end:] = 0
if return_logits:
all_logits.append(logits)
else:
del logits
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1] if return_logits else None,
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""

View File

@@ -1,11 +1,10 @@
import torch
from typing import Dict, List, Literal, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead
from typing import TYPE_CHECKING, Literal
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
valuehead_state_dict = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
@@ -16,22 +15,3 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
"summary.weight": getattr(model, "{}_head_weight".format(target)),
"summary.bias": getattr(model, "{}_head_bias".format(target))
})
def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead,
layer_norm_names: List[str] = LAYERNORM_NAMES,
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
layer_norm_state_dict = {}
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
if layer_norm_params is not None:
param.data = layer_norm_params[name] # restore float32 weights
else:
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
param.data = param.data.to(torch.float16)
return model, layer_norm_state_dict

View File

@@ -1,27 +1,30 @@
# Inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from trl import PPOConfig
from torch.optim import AdamW
from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
from transformers.utils.versions import require_version
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
@@ -35,24 +38,35 @@ def run_ppo(
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_cuda_cache=True
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
total_train_batch_size = \
if finetuning_args.ppo_score_norm:
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
ppo_config.use_score_scaling = True
ppo_config.use_score_norm = True
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps
)
# Initialize our Trainer
ppo_trainer = PPOPeftTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks,
compute_dtype=model_args.compute_dtype,
config=ppo_config,
model=model,
ref_model=None,
@@ -63,8 +77,10 @@ def run_ppo(
lr_scheduler=lr_scheduler
)
# Training
if training_args.do_train:
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -1,32 +1,30 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
import math
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
trainer = PeftTrainer(
@@ -36,12 +34,12 @@ def run_pt(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train()
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
@@ -58,6 +56,5 @@ def run_pt(
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

View File

@@ -1,8 +1,10 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
@@ -15,5 +17,11 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
features = [
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
}
for key in ("chosen_ids", "rejected_ids") for feature in features
]
return super().__call__(features)

View File

@@ -1,13 +1,15 @@
import os
import json
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
logger = get_logger(__name__)
@@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
def compute_loss(
self,
model: PreTrainedModel,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -40,13 +42,15 @@ class PairwisePeftTrainer(PeftTrainer):
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.

View File

@@ -2,25 +2,26 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@@ -38,7 +39,7 @@ def run_rm(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Training

View File

@@ -1,7 +1,6 @@
import numpy as np
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple, Union
from transformers.tokenization_utils import PreTrainedTokenizer
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
@dataclass
class ComputeMetrics:
@@ -16,14 +18,14 @@ class ComputeMetrics:
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
"""
tokenizer: PreTrainedTokenizer
tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -47,6 +49,5 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}

View File

@@ -3,13 +3,15 @@ import json
import torch
import numpy as np
import torch.nn as nn
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
@@ -48,11 +50,12 @@ class Seq2SeqPeftTrainer(PeftTrainer):
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
generated_tokens = (
generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
if generated_tokens is not None:
generated_tokens[:, :max(prompt_len, label_len)] = (
self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
)
return (loss, generated_tokens, labels)
return loss, generated_tokens, labels
def _pad_tensors_to_target_len(
self,
@@ -70,18 +73,15 @@ class Seq2SeqPeftTrainer(PeftTrainer):
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
pad_token_id = self.tokenizer.pad_token_id
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model.")
raise ValueError("PAD token is required.")
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor
return padded_tensor.contiguous() # in contiguous memory
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.

View File

@@ -1,25 +1,28 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
@@ -44,21 +47,18 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Keyword arguments for `model.generate`
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_new_tokens": data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
# Training
if training_args.do_train:
train_result = trainer.train()
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks
if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
try:
tokenizer.save_pretrained(training_args.output_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
run_exp()

View File

@@ -0,0 +1 @@
from llmtuner.webui.interface import create_ui, create_web_demo

View File

@@ -1,22 +1,22 @@
import os
from typing import List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.tuner import get_infer_args
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self, *args):
def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
if lazy_init:
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if len(args) != 0:
super().__init__(*args)
else:
super().__init__(args)
def load_model(
self,
@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str
system_prompt: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
@@ -53,11 +53,11 @@ class WebChatModel(ChatModel):
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
system_prompt=system_prompt
)
super().__init__(*get_infer_args(args))
super().__init__(args)
yield ALERTS["info_loaded"][lang]
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
prefix: str,
system: str,
max_new_tokens: int,
top_p: float,
temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
response = self.postprocess(response)
@@ -90,6 +90,8 @@ class WebChatModel(ChatModel):
yield chatbot, new_history
def postprocess(self, response: str) -> str:
response = response.replace("<", "&lt;")
response = response.replace(">", "&gt;")
return response
blocks = response.split("```")
for i, block in enumerate(blocks):
if i % 2 == 0:
blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
return "```".join(blocks)

View File

@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
DEFAULT_CACHE_DIR = "cache"
@@ -29,12 +29,14 @@ def load_config() -> Dict[str, Any]:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {"last_model": "", "path_dict": {}}
return {"lang": "", "last_model": "", "path_dict": {}}
def save_config(model_name: str, model_path: str) -> None:
def save_config(lang: str, model_name: str, model_path: str) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
if model_name:
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
@@ -46,6 +48,12 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def get_template(model_name: str) -> str:
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
return "default"
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)

View File

@@ -1,5 +1,6 @@
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab
from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab
from llmtuner.webui.components.chatbot import create_chat_box

View File

@@ -1,22 +1,23 @@
from typing import Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel,
chat_model: "WebChatModel",
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
prefix = gr.Textbox(show_label=False)
system = gr.Textbox(show_label=False)
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@@ -30,7 +31,7 @@ def create_chat_box(
submit_btn.click(
chat_model.predict,
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
@@ -40,7 +41,7 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
prefix=prefix,
system=system,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,

View File

@@ -1,10 +1,12 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)
@@ -14,6 +16,6 @@ def create_preview_box() -> Tuple[Block, Component, Component, Component]:
close_btn = gr.Button()
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return preview_box, preview_count, preview_samples, close_btn

View File

@@ -1,24 +1,31 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -28,22 +35,24 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
predict = gr.Checkbox(value=True)
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
start_btn.click(
runner.run_eval,
[
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
top_elems["system_prompt"],
dataset_dir,
dataset,
max_source_length,
@@ -51,15 +60,21 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
max_samples,
batch_size,
predict
],
[output_box]
)
]
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
start_btn.click(runner.run_eval, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
data_preview_btn=data_preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
@@ -68,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
max_samples=max_samples,
batch_size=batch_size,
predict=predict,
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_box=output_box

View File

@@ -1,11 +1,13 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.utils import export_model
from llmtuner.webui.utils import save_model
if TYPE_CHECKING:
from gradio.components import Component
def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
@@ -14,12 +16,13 @@ def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
info_box = gr.Textbox(show_label=False, interactive=False)
export_btn.click(
export_model,
save_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["template"],
max_shard_size,
save_dir
],

View File

@@ -1,13 +1,15 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
@@ -26,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"]
top_elems["system_prompt"]
],
[info_box]
).then(

View File

@@ -1,138 +0,0 @@
from typing import Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
max_grad_norm = gr.Textbox(value="1.0")
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
with gr.Row():
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
with gr.Column(scale=4):
output_dir = gr.Textbox()
with gr.Box():
output_box = gr.Markdown()
with gr.Column(scale=1):
loss_viewer = gr.Plot()
start_btn.click(
runner.run_train,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
learning_rate,
num_train_epochs,
max_samples,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
max_grad_norm,
dev_ratio,
logging_steps,
save_steps,
warmup_steps,
compute_type,
lora_rank,
lora_dropout,
lora_target,
output_dir
],
[output_box]
)
stop_btn.click(runner.set_abort, queue=False)
output_box.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
dev_ratio=dev_ratio,
advanced_tab=advanced_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
output_box=output_box,
loss_viewer=loss_viewer
)

View File

@@ -1,39 +1,46 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, Component]:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown([8, 4], scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
source_prefix = gr.Textbox(scale=2)
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
system_prompt = gr.Textbox(scale=2)
lang.change(save_config, [lang, model_name, model_path])
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
).then(
get_template, [model_name], [template]
) # do not save config since the below line will save
model_path.change(save_config, [model_name, model_path])
model_path.change(save_config, [lang, model_name, model_path])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -41,7 +48,9 @@ def create_top() -> Dict[str, Component]:
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
refresh_btn.click(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
)
return dict(
lang=lang,
@@ -53,5 +62,5 @@ def create_top() -> Dict[str, Component]:
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
source_prefix=source_prefix
system_prompt=system_prompt
)

View File

@@ -0,0 +1,184 @@
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.extras.constants import STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
padding_side = gr.Radio(choices=["left", "right"], value="left")
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
resume_lora_training = gr.Checkbox(value=True, scale=1)
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
reward_model = gr.Dropdown(scale=2)
refresh_btn = gr.Button(scale=1)
refresh_btn.click(
list_checkpoint,
[top_elems["model_name"], top_elems["finetuning_type"]],
[reward_model],
queue=False
)
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
with gr.Row():
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"],
training_stage,
dataset_dir,
dataset,
max_source_length,
max_target_length,
learning_rate,
num_train_epochs,
max_samples,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
max_grad_norm,
val_size,
logging_steps,
save_steps,
warmup_steps,
compute_type,
padding_side,
lora_rank,
lora_dropout,
lora_target,
resume_lora_training,
dpo_beta,
reward_model,
output_dir
]
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(runner.preview_train, input_components, output_components)
start_btn.click(runner.run_train, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
process_bar.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
return dict(
training_stage=training_stage,
dataset_dir=dataset_dir,
dataset=dataset,
data_preview_btn=data_preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
val_size=val_size,
advanced_tab=advanced_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
padding_side=padding_side,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
rlhf_tab=rlhf_tab,
dpo_beta=dpo_beta,
reward_model=reward_model,
refresh_btn=refresh_btn,
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
output_box=output_box,
loss_viewer=loss_viewer
)

View File

@@ -3,11 +3,13 @@ from transformers.utils.versions import require_version
from llmtuner.webui.components import (
create_top,
create_sft_tab,
create_train_tab,
create_eval_tab,
create_infer_tab,
create_export_tab
create_export_tab,
create_chat_box
)
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
@@ -22,8 +24,8 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
with gr.Tab("SFT"):
sft_elems = create_sft_tab(top_elems, runner)
with gr.Tab("Train"):
train_elems = create_train_tab(top_elems, runner)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
@@ -34,7 +36,7 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)
demo.load(
@@ -47,12 +49,30 @@ def create_ui() -> gr.Blocks:
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
queue=False
)
return demo
def create_web_demo() -> gr.Blocks:
chat_model = WebChatModel(lazy_init=False)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"])
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
return demo
if __name__ == "__main__":
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)

View File

@@ -77,7 +77,7 @@ LOCALES = {
"info": "构建提示词时使用的模板"
}
},
"source_prefix": {
"system_prompt": {
"en": {
"label": "System prompt (optional)",
"info": "A sequence used as the default system prompt."
@@ -87,6 +87,16 @@ LOCALES = {
"info": "默认使用的系统提示词"
}
},
"training_stage": {
"en": {
"label": "Stage",
"info": "The stage to perform in training."
},
"zh": {
"label": "训练阶段",
"info": "目前采用的训练方式。"
}
},
"dataset_dir": {
"en": {
"label": "Data dir",
@@ -105,12 +115,12 @@ LOCALES = {
"label": "数据集"
}
},
"preview_btn": {
"data_preview_btn": {
"en": {
"value": "Preview"
"value": "Preview dataset"
},
"zh": {
"value": "预览"
"value": "预览数据集"
}
},
"preview_count": {
@@ -227,9 +237,9 @@ LOCALES = {
"info": "用于梯度裁剪的范数。"
}
},
"dev_ratio": {
"val_size": {
"en": {
"label": "Dev ratio",
"label": "Val size",
"info": "Proportion of data in the dev set."
},
"zh": {
@@ -277,6 +287,16 @@ LOCALES = {
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
},
"padding_side": {
"en": {
"label": "Padding side",
"info": "The side on which the model should have padding applied."
},
"zh": {
"label": "填充位置",
"info": "使用左填充或右填充。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
@@ -315,6 +335,52 @@ LOCALES = {
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
}
},
"resume_lora_training": {
"en": {
"label": "Resume LoRA training",
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
},
"zh": {
"label": "继续上次的训练",
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
}
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations"
},
"zh": {
"label": "RLHF 参数设置"
}
},
"dpo_beta": {
"en": {
"label": "DPO beta",
"info": "Value of the beta parameter in the DPO loss."
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。"
}
},
"reward_model": {
"en": {
"label": "Reward model",
"info": "Checkpoint of the reward model for PPO training."
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中奖励模型的断点路径。"
}
},
"cmd_preview_btn": {
"en": {
"value": "Preview command"
},
"zh": {
"value": "预览命令"
}
},
"start_btn": {
"en": {
"value": "Start"
@@ -389,7 +455,7 @@ LOCALES = {
"value": "模型未加载,请先加载模型。"
}
},
"prefix": {
"system": {
"en": {
"placeholder": "System prompt (optional)"
},
@@ -513,6 +579,10 @@ ALERTS = {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"err_failed": {
"en": "Failed.",
"zh": "训练出错。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

View File

@@ -1,6 +1,6 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from typing import Any, Dict, List
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
@@ -12,24 +12,32 @@ class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
self.elem_list = elem_list
def gen_refresh(self) -> Dict[str, Any]:
def gen_refresh(self, lang: str) -> Dict[str, Any]:
refresh_dict = {
"dataset": {"choices": list_dataset()["choices"]},
"output_dir": {"value": get_time()}
}
user_config = load_config()
if lang:
refresh_dict["lang"] = {"value": lang}
else:
refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
if user_config["last_model"]:
refresh_dict["model_name"] = {"value": user_config["last_model"]}
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, dict]:
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
update_dict = {}
refresh_dict = self.gen_refresh()
refresh_dict = self.gen_refresh(lang)
for elems in self.elem_list:
for name, component in elems.items():
update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
update_dict[component] = gr.update(
**LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
)
return update_dict

View File

@@ -1,18 +1,20 @@
import gradio as gr
import logging
import os
import threading
import time
import transformers
from typing import Generator, List, Optional, Tuple
from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
class Runner:
@@ -20,49 +22,46 @@ class Runner:
def __init__(self):
self.aborted = False
self.running = False
self.logger_handler = LoggerHandler()
self.logger_handler.setLevel(logging.INFO)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def set_abort(self):
self.aborted = True
self.running = False
def initialize(
def _initialize(
self, lang: str, model_name: str, dataset: List[str]
) -> Tuple[str, str, LoggerHandler, LogCallback]:
) -> str:
if self.running:
return None, ALERTS["err_conflict"][lang], None, None
return ALERTS["err_conflict"][lang]
if not model_name:
return None, ALERTS["err_no_model"][lang], None, None
return ALERTS["err_no_model"][lang]
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
return None, ALERTS["err_no_path"][lang], None, None
if not get_model_path(model_name):
return ALERTS["err_no_path"][lang]
if len(dataset) == 0:
return None, ALERTS["err_no_dataset"][lang], None, None
return ALERTS["err_no_dataset"][lang]
self.aborted = False
self.running = True
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
return ""
logger_handler = LoggerHandler()
logger_handler.setLevel(logging.INFO)
logging.root.addHandler(logger_handler)
transformers.logging.add_handler(logger_handler)
trainer_callback = LogCallback(self)
return model_name_or_path, "", logger_handler, trainer_callback
def finalize(
self, lang: str, finish_info: Optional[str] = None
def _finalize(
self, lang: str, finish_info: str
) -> str:
self.running = False
torch_gc()
if self.aborted:
return ALERTS["info_aborted"][lang]
else:
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
return finish_info
def run_train(
def _parse_train_args(
self,
lang: str,
model_name: str,
@@ -70,7 +69,8 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
system_prompt: str,
training_stage: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@@ -82,37 +82,39 @@ class Runner:
gradient_accumulation_steps: int,
lr_scheduler_type: str,
max_grad_norm: str,
dev_ratio: float,
val_size: float,
logging_steps: int,
save_steps: int,
warmup_steps: int,
compute_type: str,
padding_side: str,
lora_rank: int,
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
dpo_beta: float,
reward_model: str,
output_dir: str
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
return
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
[os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
)
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
args = dict(
model_name_or_path=model_name_or_path,
stage="sft",
model_name_or_path=get_model_path(model_name),
do_train=True,
overwrite_cache=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
@@ -127,42 +129,40 @@ class Runner:
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
fp16=(compute_type == "fp16"),
bf16=(compute_type == "bf16"),
padding_side=padding_side,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
resume_lora_training=resume_lora_training,
output_dir=output_dir
)
args[compute_type] = True
if dev_ratio > 1e-6:
args["dev_ratio"] = dev_ratio
if training_stage == "Reward Modeling":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif training_stage == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
args["reward_model"] = reward_model
args["padding_side"] = "left"
val_size = 0
elif training_stage == "DPO":
args["stage"] = "dpo"
args["resume_lora_training"] = False
args["dpo_beta"] = dpo_beta
elif training_stage == "Pre-Training":
args["stage"] = "pt"
if val_size > 1e-6:
args["val_size"] = val_size
args["evaluation_strategy"] = "steps"
args["eval_steps"] = save_steps
args["load_best_model_at_end"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
return lang, model_name, dataset, output_dir, args
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
thread.start()
while thread.is_alive():
time.sleep(1)
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield self.finalize(lang)
def run_eval(
def _parse_eval_args(
self,
lang: str,
model_name: str,
@@ -170,7 +170,7 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
system_prompt: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@@ -178,12 +178,7 @@ class Runner:
max_samples: str,
batch_size: int,
predict: bool
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
return
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -194,15 +189,16 @@ class Runner:
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
args = dict(
model_name_or_path=model_name_or_path,
stage="sft",
model_name_or_path=get_model_path(model_name),
do_eval=True,
overwrite_cache=True,
predict_with_generate=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
@@ -216,23 +212,72 @@ class Runner:
args.pop("do_eval", None)
args["do_predict"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
return lang, model_name, dataset, output_dir, args
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, _, args = self._parse_train_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, _, args = self._parse_eval_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
return
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.update(visible=False)
def run_eval(self, *args) -> Generator[str, None, None]:
lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
return
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.update(visible=False)

View File

@@ -3,22 +3,30 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
def format_info(log: str, tracker: dict) -> str:
info = log
if "current_steps" in tracker:
info += "Running **{:d}/{:d}**: {} < {}\n".format(
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
return gr.update(visible=False)
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps,
callback.max_steps,
callback.elapsed_time,
callback.remaining_time
)
return info
return gr.update(label=label, value=percentage, visible=True)
def get_time() -> str:
@@ -54,6 +62,18 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
return gr.update(interactive=True)
def gen_cmd(args: Dict[str, Any]) -> str:
if args.get("do_train", None):
args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
for k, v in args.items():
if v is not None and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text)
return cmd_text
def get_eval_results(path: os.PathLike) -> str:
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
@@ -87,8 +107,14 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
return fig
def export_model(
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
def save_model(
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
template: str,
max_shard_size: int,
save_dir: str
) -> Generator[str, None, None]:
if not model_name:
yield ALERTS["err_no_model"][lang]
@@ -114,12 +140,11 @@ def export_model(
args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type
finetuning_type=finetuning_type,
template=template,
output_dir=save_dir
)
yield ALERTS["info_exporting"][lang]
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
tokenizer.save_pretrained(save_dir)
export_model(args, max_shard_size="{}GB".format(max_shard_size))
yield ALERTS["info_exported"][lang]

View File

@@ -1,17 +1,8 @@
from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
from llmtuner import run_exp
def main():
model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args)
run_exp()
def _mp_fn(index):

View File

@@ -1,10 +1,10 @@
from llmtuner.webui.interface import create_ui
from llmtuner import create_ui
def main():
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
if __name__ == "__main__":

View File

@@ -1,35 +1,10 @@
# coding=utf-8
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner.tuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
from llmtuner import create_web_demo
def main():
chat_model = WebChatModel(*get_infer_args())
with gr.Blocks(title="Web Demo") as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
demo = create_web_demo()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
if __name__ == "__main__":