375 Commits

Author SHA1 Message Date
hiyouga
95d0f77fc2 release v0.3.0
Former-commit-id: de7f5b622340ab09ebbe57ad2703e63d06dfdeea
2023-11-16 16:00:11 +08:00
hiyouga
9b2654277b update readme
Former-commit-id: 4018aabc5d1623033d27a8aced25804de79b7e7b
2023-11-16 15:58:37 +08:00
hoshi-hiyouga
f1b3bdac3f Merge #1525 from hiyouga/dev, fix #224 #336 #931 #936 #1011
Refactor llmtuner, support full-parameter RLHF

Former-commit-id: 3b92826803dc69471827b4f8204c2c3dc5310619
2023-11-16 15:47:13 +08:00
hiyouga
595fdbd95d fix css
Former-commit-id: 7afec127f60257462828298b25a5f6fd9c6f42c5
2023-11-16 15:45:38 +08:00
hiyouga
dab9385297 fix bug in web ui
Former-commit-id: a598f145ec903dd2b2c984d951b6c450b142ece5
2023-11-16 15:21:24 +08:00
hiyouga
df83def566 update ppo and demo in webui
Former-commit-id: de7571704c82121db13e3fc907379d2453100191
2023-11-16 14:55:26 +08:00
hiyouga
f9d4e37b3c fix bug in freeze tuning
Former-commit-id: f6b436a08421ca17d64abc51497f4aa43729a43b
2023-11-16 14:25:11 +08:00
hiyouga
e59a3d71e0 tiny fix
Former-commit-id: d65519d8a44b73bbb713741c23465f13c35c83f5
2023-11-16 03:27:19 +08:00
hiyouga
de3a84ac59 fix rlhf callback
Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
2023-11-16 03:26:19 +08:00
hiyouga
e017266b98 fix bug in PPO training
Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
2023-11-16 02:32:54 +08:00
hiyouga
f81a8a5e5c fix import bug
Former-commit-id: 2356029cdd120d5f7bf630b80681ce8c53bff90d
2023-11-16 02:27:03 +08:00
hiyouga
7a3a0144a5 support full-parameter PPO
Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
2023-11-16 02:08:04 +08:00
hiyouga
8263b2d32d add demo mode for web UI
Former-commit-id: 5ad34f08b4e1505d7933b973497347f126b2e818
2023-11-15 23:51:26 +08:00
hoshi-hiyouga
833cd490b8 Create CODE_OF_CONDUCT.md
Former-commit-id: 6bee64cdf9c75488033e600fb5b48738daa1ed3b
2023-11-15 20:42:15 +08:00
hiyouga
2162c37e41 update readme and constants
Former-commit-id: 7d83e3dd9101a4fdd0b589d0c1f7b609c0feecd1
2023-11-15 18:04:37 +08:00
hiyouga
b2ac8376e1 support multiple modules in freeze training #1514
Former-commit-id: 60abac70dfd778df2ae8b3a2e960ed8b607d7ab6
2023-11-15 17:08:18 +08:00
hiyouga
8079584143 fix imports
Former-commit-id: 6156f1abef631c675d150dd1cb0325cfc3820c91
2023-11-15 16:47:45 +08:00
hiyouga
09a4474e7f disentangle model from tuner and rename modules
Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
2023-11-15 16:29:09 +08:00
hiyouga
81530133ff fix #1507
Former-commit-id: 1ba9c53bd9743fa95fca1516c0ed9da352dbe9a1
2023-11-15 16:22:32 +08:00
hiyouga
cc4b384ac3 Update cal_lr.py
Former-commit-id: b92ef6c80ae108982046ec1419efb67c8b10b250
2023-11-14 21:14:42 +08:00
hiyouga
3852daf447 Update cal_lr.py
Former-commit-id: b6c3f9b24324403db41c5680a00aabc6d53bbeb9
2023-11-14 21:13:01 +08:00
hiyouga
5c97111f9d Update cal_lr.py
Former-commit-id: 1258eec806f6f4580a6eb7d9eb44f431f4c0da4f
2023-11-14 21:09:30 +08:00
hiyouga
75dd1f0f7e add cal_lr.py
Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
2023-11-14 20:58:37 +08:00
hiyouga
c9a4551012 fix #1494
Former-commit-id: 07c8d734529f03e47ef638a1bda222e8824d3d38
2023-11-14 18:07:20 +08:00
hiyouga
87197ba91d fix #1489
Former-commit-id: ebdeaca9cdfd6138c690a0fcb9f676deaddff177
2023-11-14 15:27:05 +08:00
hiyouga
7461bf84e5 support eval remote dataset
Former-commit-id: 71dd2698bf8c0b9ef7af995fb1e49e39fa66074e
2023-11-14 02:42:30 +08:00
hiyouga
fbc0357b2e fix dc link
Former-commit-id: 04c3a1f1c98d8f191102e359def0c8dcdc9621e3
2023-11-13 23:22:56 +08:00
hiyouga
ec334f5891 release v0.2.2, fix #1478 #1466
Former-commit-id: c9534c411716e1dceb54c5eb35fe845c93ee2973
2023-11-13 23:09:05 +08:00
hiyouga
885efe772e fix #424
Former-commit-id: ca24d445f825e120e659f5cd080a954c2243b8f2
2023-11-13 22:42:23 +08:00
hiyouga
64fc9ba678 refactor evaluation, upgrade trl to 074
Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
2023-11-13 22:20:35 +08:00
hiyouga
989eccd286 fix flashattn warning
Former-commit-id: 6eb095d39bd82fdbdb729a0ea57fc7246e3a60d6
2023-11-10 18:34:54 +08:00
hiyouga
f0766a2ab0 add todo
Former-commit-id: 0bd884feb11736d0ab24ca19885151cb47d9dcd3
2023-11-10 14:38:18 +08:00
hiyouga
178b85ff9a refactor constants
Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
2023-11-10 14:16:10 +08:00
hiyouga
68dd1ef121 tiny fix
Former-commit-id: 97ba2027bb1ddc01a3c824c40d5a180828810c2c
2023-11-09 17:20:49 +08:00
hoshi-hiyouga
b222cffe98 Merge pull request #1454 from yyq/main
Update finetuning_args.py

Former-commit-id: e67d8b93705383a8590f99e26e9fe8f663712aef
2023-11-09 17:12:18 +08:00
Yanqing
b4f1ab93d1 Update finetuning_args.py
更新 chatglm/falcon/bloom 的 lora_target 的名称

Former-commit-id: 06606739af035a80ae9ddba9d12c965ed289305d
2023-11-09 17:04:40 +08:00
hiyouga
f2e139f5cd fix #1452
Former-commit-id: 4d16214467715df458e24d03bb7d303d62b8bdcd
2023-11-09 16:41:32 +08:00
hiyouga
a9cbca1604 update readme
Former-commit-id: f7ead54042868550a3e8a6928ea3c0e2673f15b3
2023-11-09 16:00:24 +08:00
hiyouga
3a30ce6c16 release v0.2.1
Former-commit-id: 1c30f2be0140f5ab47c2bc811170d0271a0cdad6
2023-11-09 15:54:16 +08:00
hiyouga
48ec5355f9 add template, modify datasets
Former-commit-id: 81e54beb4d0f792f4fd7f450643caaf10f2f0b7d
2023-11-09 15:53:23 +08:00
hoshi-hiyouga
11859bc322 Merge pull request #1436 from lvzii/main
fix tokenizer config changed after pretrain

Former-commit-id: f485c3983e413fd3a3a57b451800705b072869a7
2023-11-09 14:30:50 +08:00
hiyouga
28c67a5be8 support parquet format #1446
Former-commit-id: 44a3b9ac9f10d2012b8ad3d8c48123db9a0da2f1
2023-11-09 14:17:40 +08:00
hiyouga
44fe93e9b0 fix #1438 #1439
Former-commit-id: 84260d58dda22adc32c26bc943ed2a36fd01341d
2023-11-09 13:45:10 +08:00
lvzi
09a1681b63 fix tokenizer config changed after pretrain
Changing tokenizer's attribute at preprocessing stage will result in saving a wrong tokenizer.
for example, baichuan2

Former-commit-id: 19942b5314b84267691f0a5657d0679f2ddbe58b
2023-11-08 15:50:46 +08:00
hiyouga
f5ba2190fb fix ppo train and dpo eval
Former-commit-id: ced863031836632cb5920e22ae6991f251372118
2023-11-07 22:48:51 +08:00
hiyouga
14a38b5069 fix #1422
Former-commit-id: 25d7bbd0a5142f001bd2ff498df07b24137050a9
2023-11-07 19:42:01 +08:00
hiyouga
f23e5b602a fix reward model loading
Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
2023-11-07 17:20:51 +08:00
hiyouga
857696ed9c fix args
Former-commit-id: 44d0fa2ac6a6423c7ddaf91eb8998c1b9248c04e
2023-11-07 16:36:06 +08:00
hiyouga
2084133058 update info
Former-commit-id: 89643b8ac1e3fa8d2f29f1c88e4d4503410c0d05
2023-11-07 16:28:21 +08:00
hiyouga
f7f0c3070e delete file
Former-commit-id: 7d6355db0fd5809b99f3fa42753cf4dffd251fd1
2023-11-07 16:20:12 +08:00
hiyouga
46235aa514 fix #1418
Former-commit-id: 9bfecc72c53cf95fea4a9ff02ec40a65da6d4f54
2023-11-07 16:17:22 +08:00
hiyouga
2eb65d21ac upgrade peft, fix #1088 #1411
Former-commit-id: aa7d104f8e050d12cb8f585bc8a52c850995500f
2023-11-07 16:13:36 +08:00
hiyouga
37a0d62a82 update requirements
Former-commit-id: 82ebbbbb80b3f3f616274210970738d0f44b5a0a
2023-11-06 19:01:21 +08:00
hiyouga
21ac46e439 use seed in evaluate.py
Former-commit-id: ab5cac1dfa681933f3266827f80068ce798b4c56
2023-11-06 18:17:51 +08:00
hiyouga
ba3e8ba20c update readme (list in alphabetical order)
Former-commit-id: e6a67b5477ee095bd92764581cfe6af57e799a69
2023-11-06 17:18:12 +08:00
hiyouga
2c48e798ca update templates
Former-commit-id: 85be2e242b062283f192c4c4d0715dc1e8a68589
2023-11-06 12:25:47 +08:00
hiyouga
4e40f5b62b fix #1383
Former-commit-id: 9b8a782aa80f27c3e2a2e2621f9be17cae1a27e8
2023-11-06 11:42:23 +08:00
hiyouga
2a8892b785 fix deepseek template
Former-commit-id: 1fdbcdad9a1cdb20299350efd87a8e5cb8c625a3
2023-11-05 13:08:46 +08:00
hiyouga
ee3b33ff03 support deepseek coder #1378
Former-commit-id: ae0c829917b9de10e71199c85c77a52cdcd2b7b3
2023-11-05 12:51:03 +08:00
hiyouga
b2c3001f8e fix #1365
Former-commit-id: 0277d120e62164bb7fa1d6043b8fcc52c881fe96
2023-11-05 12:21:07 +08:00
hiyouga
6cfe1e1ac2 tiny fix
Former-commit-id: 594c510a20d6c2782d7b7ffff18931e3003e6c22
2023-11-03 01:26:06 +08:00
hiyouga
52326870e4 fix #1290
Former-commit-id: ad911d258c4cea16f54d09bc192e076c21d26394
2023-11-03 00:44:53 +08:00
hiyouga
217fde0918 fix bug in data loader, support dpo eval
Former-commit-id: f4f3dcff990468a2fa864b7176adcebbcf16dac9
2023-11-03 00:34:26 +08:00
hiyouga
065021d82a update data readme
Former-commit-id: 6a65ef44ed58714c611da60b5af96b85352e8735
2023-11-03 00:15:23 +08:00
hiyouga
4bb643e685 update data readme (zh)
Former-commit-id: b32fb3a984c681732b82f6544d6c05a98c34cf4c
2023-11-02 23:42:49 +08:00
hiyouga
b77c745b1a support sharegpt format, add datasets
Former-commit-id: 202daf8987ccb7523be03ca535b572b5c9e65994
2023-11-02 23:10:04 +08:00
hiyouga
7d13501b94 support pagination in webui preview
Former-commit-id: f2307e26b9c2ce5d60917cce5a9638466ea676c8
2023-11-02 21:21:45 +08:00
hiyouga
ac74639b32 fix webui
Former-commit-id: 9192948fa221c0275ddfa579ef6b3442d45b8962
2023-11-02 18:03:14 +08:00
hiyouga
12fa56ae68 support warning in webui
Former-commit-id: 9903b523fad2f0ec0e66c3d313823bd4674bfa2b
2023-11-02 17:57:04 +08:00
hiyouga
f11b863f4b fix #1349
Former-commit-id: 556c023eab2a68560b26a7d5318a79410fb0c700
2023-11-02 17:02:44 +08:00
hiyouga
f3e4b72957 fix #1356
Former-commit-id: d2ed436108a339d405dad1be1ca15baca3d6d3e4
2023-11-02 16:51:52 +08:00
hiyouga
8d52fb46ca fix #1325
Former-commit-id: 59f2cbbd52d4646fbd1ba83032bf522ecc49a50f
2023-11-01 23:38:49 +08:00
hiyouga
dab8f45033 fix chat
Former-commit-id: 68f2b3df09c4c8638b9e225fd5b8aed3541e97a0
2023-11-01 23:07:58 +08:00
hiyouga
bff8b02543 update gradio, support multiple resp in api
Former-commit-id: a34263e7c0e07a080276d164cdab9f12f1d767d2
2023-11-01 23:02:16 +08:00
hiyouga
2406200914 fix SFT trainer
Former-commit-id: bf09b6a6cd75cc2738d9af6b8c30bcbba77fa9b5
2023-10-31 21:52:52 +08:00
hiyouga
db06fcfc84 fix #1316
Former-commit-id: 88a753fe80e277007bac2264aee24024e18f2314
2023-10-31 11:32:08 +08:00
hiyouga
93b9f74e9f update projects
Former-commit-id: 33d58e9171ad2693b9d54715eb61a6f4326c59f4
2023-10-29 22:53:47 +08:00
hiyouga
33ec844f76 add projects
Former-commit-id: 495a68cd5962dd3b3af7e4a920d91ac25531a862
2023-10-29 22:07:13 +08:00
hiyouga
0f727b393e update constants
Former-commit-id: ebacbb1072045924a7e335cc9dda488d6f0be8b3
2023-10-29 13:30:20 +08:00
hiyouga
7da2aad6ee fix vicuna template
Former-commit-id: a98eda0803e4b73a24f12d848e14161451921e98
2023-10-27 22:15:25 +08:00
hiyouga
6f09f50d02 fix chatglm3 template
Former-commit-id: 69bcbc9f6c98e4f4ad97ec0306b33ab21923d311
2023-10-27 21:12:06 +08:00
hiyouga
5919832059 update readme
Former-commit-id: 6fb92c7088316c56ce8656e540fc47b0a5a1bf18
2023-10-27 19:19:03 +08:00
hiyouga
f7635c1afc support chatglm3
Former-commit-id: ba82e13bbeed3b262d301196b1860d73f319401d
2023-10-27 19:16:28 +08:00
hiyouga
c762168ed0 support dataset cache
Former-commit-id: f79ee62eb4a2a4a01cb4e2a6aa2d07158cf8eb59
2023-10-26 21:48:45 +08:00
hiyouga
67a46e553f fix #1287
Former-commit-id: d885aca472c6448bbf9a9e8d16bead92038825e3
2023-10-26 17:49:41 +08:00
hiyouga
e406f37b54 fix #1285
Former-commit-id: 2f8fe4439506e844b147fe38b5eb878c5748c31c
2023-10-26 16:34:52 +08:00
hiyouga
62fe877124 remove filter in preprocess
Former-commit-id: 9eac08b35fec47129a29c401ca265343f8388ab0
2023-10-23 23:46:02 +08:00
hiyouga
a0e682ba79 update neftune logic
Former-commit-id: bb4f0589ed23bf0236d3e918272ad64f0a05ef39
2023-10-22 17:42:13 +08:00
hiyouga
49e8a87383 fix webui
Former-commit-id: a5a5a7bc1f53d36e1b26e418999465903cb7d9ed
2023-10-22 17:24:56 +08:00
hiyouga
b2764b49ca add new options in webui
Former-commit-id: 6698b832dd9cc2d7d60be4fa5ab90e34a7e9d8e0
2023-10-22 17:17:58 +08:00
hiyouga
06b810de8f fix recursion error
Former-commit-id: c7938188c36a71a878bca982b7dd151195164986
2023-10-22 16:28:37 +08:00
hiyouga
6da51565f5 reimplement neftune
Former-commit-id: efe9e5a194d3a9f052701d904715238816e4c09e
2023-10-22 16:15:08 +08:00
hoshi-hiyouga
1f69965239 Merge pull request #1252 from anvie/neftune
add NEFTune optimization

Former-commit-id: 85d5c5fbe731f486c3e83812227fa05edc131487
2023-10-22 15:59:20 +08:00
anvie
af2d61178d add NEFTune optimization
Former-commit-id: 603e0298af64116ac07130fe6661a9ba823c186c
2023-10-21 13:24:10 +07:00
hiyouga
6a955ccf4f fix openchat template
Former-commit-id: 88b9b657bc50495ac4c42f64195fc652fe4ca3df
2023-10-21 01:25:42 +08:00
hiyouga
c0658711ca fix tokenizer padding side in evaluate.py
Former-commit-id: bcb43ff8ba1946c1f7e7865c9d0fb47ba276935d
2023-10-21 00:30:04 +08:00
hiyouga
d602f06882 fix #1232
Former-commit-id: 49975755d47344e362145c52548fdda8783f2c0c
2023-10-20 23:28:52 +08:00
hiyouga
1cb9a38ac2 fix #1215
Former-commit-id: d91b43a8afbea4859357f2224e3d9b9d71160e6d
2023-10-19 16:19:21 +08:00
hiyouga
47a1f73d0f fix #1218
Former-commit-id: b301f35bd4a3bf368159c8f5fb4e2736f922115b
2023-10-19 16:17:41 +08:00
hiyouga
142dd63b47 fix #1228
Former-commit-id: e4e0cae3f55da2f1b566c97dbfdd7fc5b7b728a4
2023-10-19 15:54:10 +08:00
hiyouga
b1bd8370c2 fix #1217
Former-commit-id: 065fc0a6f3f005bb87e1c5c126c8b6bb470ce700
2023-10-19 15:52:24 +08:00
hiyouga
215660c8da rename webui
Former-commit-id: 26feaf80fff6177d9eb4e28ad18feb6d34d3ea27
2023-10-16 15:16:24 +08:00
hiyouga
0cafe67efe fix #1197
Former-commit-id: 00100e23fcfef9587fda4cf01c62599d996e1176
2023-10-16 15:13:46 +08:00
hoshi-hiyouga
ea83b3222b Update README_zh.md
Former-commit-id: 3450404bb9a33c3bd4b45ac4afcf51062f8c7d1d
2023-10-16 00:28:27 +08:00
hoshi-hiyouga
725087a04f Update README.md
Former-commit-id: d84896597eded79f78224faed81cc9f2df222978
2023-10-16 00:23:37 +08:00
hiyouga
d627ab4855 release v0.2.0
Former-commit-id: 7f941c1ab6c52915aa2675fa77cae5efc530fdd9
2023-10-15 20:49:43 +08:00
hiyouga
7d867e8df4 update readme
Former-commit-id: a99a92b129a3d2372e66ca73b87c3e521f144043
2023-10-15 20:28:14 +08:00
hoshi-hiyouga
3d34d44497 Update README.md
Former-commit-id: e6fcc1831dadd2ec2c0acb14697a35f6471139ab
2023-10-15 20:23:22 +08:00
hiyouga
a6f800b741 fix config, #1191
Former-commit-id: 5dbc9b355e85b203cb43ff72589374f0e04be391
2023-10-15 18:28:45 +08:00
hiyouga
a003d1fa1e disable tqdm in webui mode
Former-commit-id: 832be571bec2eefb79ea88f110b7827f5c1249e6
2023-10-15 16:18:25 +08:00
hiyouga
c2e84d4558 refactor export, fix #1190
Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
2023-10-15 16:01:48 +08:00
hiyouga
68330eab2a fix eval resuming in webui
Former-commit-id: b28b53cd06777f213ef7b925a914ff5fd357ade1
2023-10-15 15:45:38 +08:00
hiyouga
7070f3969d tiny fix
Former-commit-id: 47b7b34357708a5354d542ddc239146c6417d718
2023-10-15 05:02:48 +08:00
hiyouga
e4727ab155 fix callback
Former-commit-id: 51208655a8c1d66551b7b644247321a3583debdc
2023-10-15 04:59:44 +08:00
hoshi-hiyouga
280e7d97ad Merge pull request #1186 from hiyouga/dev
Support Web UI resuming training

Former-commit-id: fcbecd0c4cb17b883e9b780a71d2abc38228293e
2023-10-15 04:53:14 +08:00
hiyouga
31e3805fb8 implement webui resuming training
Former-commit-id: 2d41672ef52414c56c50c8b4fdc442797ba682e9
2023-10-15 04:52:19 +08:00
hiyouga
ef248dbe15 fix bugs in webui
Former-commit-id: 4befa74ea630d90e4d7a1f7d7c34d39257717ec1
2023-10-15 03:41:58 +08:00
hiyouga
6a61b4b638 refactor webui
Former-commit-id: 813ecd8e51949c21ab6fbaa51cc2b1a84ee07952
2023-10-15 03:06:21 +08:00
hiyouga
4b1473502f fix loading dtype
Former-commit-id: d54a356128f7e335c12089702cf3de7f5b4baf16
2023-10-14 20:15:24 +08:00
hiyouga
bf211d818d fix #1176 #1177
Former-commit-id: 5627a2b57c270a78095a32083e2dc7aa02162875
2023-10-14 20:00:17 +08:00
hiyouga
27dd87c890 fix #1184
Former-commit-id: 5b069a967823e659dbc70b0d50361b3ad248087e
2023-10-14 19:20:11 +08:00
hiyouga
8659084ab0 fix webui
Former-commit-id: a0fe43aac968d9f6ca4724b8d718b45c03063b91
2023-10-13 16:27:59 +08:00
hiyouga
e1c9dcea93 update readme
Former-commit-id: 9d9018fad314cdc4512b4847633489cdd7a25347
2023-10-13 13:53:43 +08:00
hiyouga
171339ab17 update discord link
Former-commit-id: f725cb4940a3a18e9f1edca986ef06d425b39710
2023-10-12 21:44:28 +08:00
hiyouga
8542ba5c69 rename repository
Former-commit-id: 6100ac080a5e52edd66b98147aede6cb77481beb
2023-10-12 21:42:29 +08:00
hiyouga
97b74d328b fix ppo args
Former-commit-id: 0f12899951808f53a482082eb116bda309775930
2023-10-11 23:40:50 +08:00
hiyouga
3198a7e5f4 refactor model_dtype, fix PPO trainer
Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
2023-10-11 23:16:01 +08:00
hiyouga
a2d08ce961 add averaging in evaluation
Former-commit-id: b39d6e0b8658e1c69bbaf6bcb6cfaa8f7af30110
2023-10-10 23:16:31 +08:00
hiyouga
bd8ea09479 fix aquila template, repair sft packing mechanism
Former-commit-id: 8c82cfa5dd4bec957426b5bf176d242c77552ab0
2023-10-10 18:49:55 +08:00
hiyouga
6d0d46c7fb tiny fix
Former-commit-id: 31ccd3329ac634b239c43d60bd955cd95670df16
2023-10-10 17:41:13 +08:00
hiyouga
820540780a update readme
Former-commit-id: 4a9c8a4f18b07455c34e6c1e6bbc81cbefd82eea
2023-10-09 20:02:50 +08:00
hiyouga
f74d600497 fix flash shift short attention
Former-commit-id: e44ad23eafa39b3ac0400b6f97cd440106a87f44
2023-10-09 17:54:48 +08:00
hiyouga
94fec9f50e fix webui args
Former-commit-id: 64aa75c8cd7c84ab4a0f1dbaf4763765ba973f54
2023-10-09 17:13:57 +08:00
hiyouga
e387a50475 fix shift short attention
Former-commit-id: 9a49cce8e6f6b222f74a07bdab40efee6a77b0f1
2023-10-09 17:07:46 +08:00
hiyouga
5c4248a29c update webui #1086
Former-commit-id: 65a48bc398f18f71f5f2659b2070e3b9593af243
2023-10-09 14:50:14 +08:00
hiyouga
f22886e2b6 fix #1097
Former-commit-id: c5b8796322d9d48e815038f9fecf0ce39036a4ee
2023-10-08 22:29:26 +08:00
hiyouga
33af3cbf37 add llamafy_qwen.py
Former-commit-id: 6cdc91543c022edcc98076488f06e809fde9bad7
2023-10-08 22:05:36 +08:00
hiyouga
728dfb1be7 fix #1068 #1074
Former-commit-id: 26c6bfd21de06cc56be9a58e2ef69045ea70cc14
2023-09-28 14:39:16 +08:00
hiyouga
e49f7f1afe fix bug in packed sft dataset
Former-commit-id: 51d26b2af6612e65a91c576da5270028da27b322
2023-09-28 01:16:46 +08:00
hiyouga
21a454fa6c tiny fix
Former-commit-id: 35b355b76d2a8f8adf3750a905224e52d03d218f
2023-09-28 01:03:04 +08:00
hiyouga
22c6c27f78 tiny fix
Former-commit-id: 7451b2ae7e58d0f1857f01a037672a8c53b1bd0d
2023-09-28 01:02:11 +08:00
hiyouga
aecbb43096 fix #1064
Former-commit-id: fd4660aa72d981d7efdad465f24a59358626c975
2023-09-28 00:53:29 +08:00
hiyouga
fa53fd2db2 fix bug in pretraining
Former-commit-id: 18a2d90bd6e7c3e1e3513e6f9d895e4048b35b04
2023-09-28 00:45:20 +08:00
hiyouga
1c150995ae fix layer norm dtype
Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
2023-09-28 00:25:55 +08:00
hiyouga
6c5d8f089e fix #1026
Former-commit-id: d0940d0dbd03d4bbcc955304566b0d5507edf9e6
2023-09-27 22:57:09 +08:00
hiyouga
dd623325e8 fix #424
Former-commit-id: daaf89f1126112a73b9f115b0f5617a8cd974a3e
2023-09-27 22:49:43 +08:00
hiyouga
e8a375c8f2 fix #1032
Former-commit-id: 1235b2da5a79ffefd1342054ea8e7dabf47398c1
2023-09-27 22:42:16 +08:00
hiyouga
386d85ae72 refactor finetuning Args
Former-commit-id: be425a70a4c8f051717cf1e4464dbd79dae4c0b5
2023-09-27 22:28:06 +08:00
hiyouga
ebb3901b05 update readme
Former-commit-id: badbc210435d92cea8799bcd1af4c738da902cd7
2023-09-27 21:57:47 +08:00
hiyouga
20130b486c support LongLoRA
Former-commit-id: 0832ed37e7947d699f17375648a52f80752c2b6b
2023-09-27 21:55:50 +08:00
hiyouga
73c48d0463 add CMMLU, update eval script
Former-commit-id: 47f31f06a946eefa5a972e4a566cf3ce05e1e111
2023-09-23 21:10:17 +08:00
hiyouga
f7cecd20e3 update evaluate
Former-commit-id: 288137a76ed1528faa39b467da22f6468ba368ee
2023-09-23 11:55:31 +08:00
hiyouga
2bc64a7636 move file
Former-commit-id: 8711ca9b5421f971ee4cb2fada23832f1021577c
2023-09-23 11:52:12 +08:00
hiyouga
9564ddbb48 shuffle few shot examples
Former-commit-id: 2c9c14c122382e640dfa41a3799628c764c99457
2023-09-23 00:53:20 +08:00
hiyouga
28062c71b5 fix MMLU
Former-commit-id: eeab92323899694010469451b8dfb1f00d685bff
2023-09-23 00:42:23 +08:00
hiyouga
35d1921081 add MMLU and C-Eval script
Former-commit-id: 3403f876127b4b99c5e3edb2834cc3b9a3a0063f
2023-09-23 00:34:17 +08:00
hiyouga
4fbdf18c70 fix #1000
Former-commit-id: 85de2d0a99e4a81fae890a963ccbb5c6142d52d4
2023-09-22 15:00:48 +08:00
hiyouga
5e07ab01f0 update readme
Former-commit-id: 776f9ea3a5837cb3f80ebe53f19e9951400bf05d
2023-09-22 14:34:13 +08:00
hiyouga
fac465a21e fix webui
Former-commit-id: e28485b476816c1bd6c34f7ff9efaa9e3fb85176
2023-09-21 19:55:38 +08:00
hiyouga
e145a2ce0c tiny fix
Former-commit-id: d24ea58c1a44b94227f4cb60f13fc1dd79997d01
2023-09-21 19:52:06 +08:00
hiyouga
dc68c313ee fix #944
Former-commit-id: 032245647848aaa4167086636b6c985268c5fee3
2023-09-21 19:51:02 +08:00
hiyouga
95c0d9ab24 tiny fix
Former-commit-id: 1a7ddd8c1d20dc251f53923bd0ab9f3f1031dd21
2023-09-21 15:25:29 +08:00
hoshi-hiyouga
46a718f339 Merge pull request #975 from statelesshz/npu-support
Add Ascend NPU support

Former-commit-id: b348c7569c0d3f46b03fb274226444ac7a80e68d
2023-09-20 14:56:50 +08:00
statelesshz
496ba46960 support export model on Ascend NPU
Former-commit-id: 50f94e6d9d62c848db7a3db85fa999d67ddd9f04
2023-09-20 10:26:02 +08:00
hiyouga
43ae0aca1d fix webui
Former-commit-id: 2aa06a5a74d98ec25ed6e1e39df11230670f5bad
2023-09-19 18:35:21 +08:00
hiyouga
b8574c1b82 fix error info
Former-commit-id: b90ed220c5e94086d2b73045eff2440ff1b58c5c
2023-09-19 18:30:23 +08:00
hiyouga
32f8b1082b add tests.cal_flops.py
Former-commit-id: 47a119db6c6e937f6ed96f70e3cda6031b9fbd0d
2023-09-16 23:40:41 +08:00
hiyouga
6443fef31a update readme
Former-commit-id: 813c2df5dc179d82c6c999f63c2640e7c3f6aaff
2023-09-16 17:33:01 +08:00
hiyouga
14c3795a7d fix #913
Former-commit-id: d67c11d69277292648dd9889a7321345e2c0c437
2023-09-15 20:58:28 +08:00
hiyouga
3d9e2de573 fix #896
Former-commit-id: 4b70d623d817460de4732749110622e4a1b51958
2023-09-14 18:37:34 +08:00
hiyouga
0ca36a0f8d fix #887
Former-commit-id: e131bc03e05ccae3c6ad8bb42ccf2cdcc2cf3cea
2023-09-14 17:56:58 +08:00
mmbwf
3e5555502a Update utils.py
Fix parameters load error.

Former-commit-id: 112850364c7fdb53e3a38d42861404fc519108ce
2023-09-14 15:38:04 +08:00
hiyouga
fbf5b5e0a9 add MathInstruct dataset
Former-commit-id: 3d1d4b47055739854cf9788a902607e1bbba3723
2023-09-13 22:30:14 +08:00
hiyouga
3305e66f8c fix ppo save model
Former-commit-id: 300ca6d904524f46cb520056e1319a1e9a13d169
2023-09-12 16:25:29 +08:00
hiyouga
e19a44c12b fix #762 #814
Former-commit-id: 9a30ee5009040afbc524dbac0dad99904b2adf5f
2023-09-12 16:10:10 +08:00
hiyouga
8b0e6b9d1b tiny fix
Former-commit-id: d8ea0691f84c971e6860526714fc9873c350b064
2023-09-11 18:27:08 +08:00
hiyouga
f3e638ac6a Release v0.1.8
Former-commit-id: d9666411375964d334d0a93ec162b27e05f70d49
2023-09-11 17:31:34 +08:00
hiyouga
42e0b30476 update flashattn, fix ppo save model
Former-commit-id: 0b08bc3dac246d4aa3f89afb7172529dcad9c39f
2023-09-11 17:25:36 +08:00
hiyouga
a09a7b650d remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
2023-09-10 22:23:23 +08:00
hiyouga
332d7bbd56 truncate readme
Former-commit-id: fed5d0cc87e4a5a023f2edae622f2820bded1509
2023-09-10 21:04:20 +08:00
hiyouga
d3b6fece71 update readme
Former-commit-id: c42fe77fec2918fe8811d48ec88e9a7c1e6f07ab
2023-09-10 21:01:20 +08:00
hiyouga
9d963b82de update readme
Former-commit-id: b4109cfe548e091cd20fa84815dce5ff3974a090
2023-09-10 20:52:21 +08:00
hiyouga
a402161631 support FlashAttention2
Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890
2023-09-10 20:43:56 +08:00
hiyouga
b481ad58e6 fix #850
Former-commit-id: e5975c4c6b8bd47ec506b0d4a4703bee05495436
2023-09-10 14:22:03 +08:00
hiyouga
f91c5f2638 fix lora target
Former-commit-id: d822e41e7ac7e310ee49e347fc45754284ce30b8
2023-09-09 17:04:45 +08:00
hiyouga
7143c551ab support lora target auto find
Former-commit-id: bce9984733d88bf013847eed523d1c75fdf0995e
2023-09-09 15:38:37 +08:00
hiyouga
50e93392dd fix chatglm2 tokenizer
Former-commit-id: 1ab60b4a93fa1be5dfe6ffbd4deb64c0f9d9b431
2023-09-09 13:50:29 +08:00
hiyouga
9f83e93839 add baichuan2 convert script
Former-commit-id: 4d676e0ea9e59c1be13ecb47734917ba78938ac8
2023-09-08 22:59:41 +08:00
hiyouga
692b132dbf fix bug in DPO data collator
Former-commit-id: 4fc262cdf1347691e253bdfbd96568db5a49c086
2023-09-08 20:45:07 +08:00
hiyouga
e70b3e8947 fix #761
Former-commit-id: be76f6cbe5143f781b6b39603b80392253b3080a
2023-09-08 20:22:18 +08:00
hiyouga
612d97db6f change to right-padding, update reward score #803
Former-commit-id: baa90415bc8f5ebd423d001378b51c3a3a6c2ec7
2023-09-08 20:04:31 +08:00
hiyouga
bb1b67c076 fix chatglm template
Former-commit-id: 69a824628b4d6a56a680a7e713b217877c6c15c5
2023-09-08 14:45:58 +08:00
hiyouga
5a75c31caa update requirements
Former-commit-id: d796a4a5709c390629bafbeb7c91fccf6a9076d0
2023-09-07 19:26:25 +08:00
hiyouga
8b9210286b fix #818
Former-commit-id: e81fd458c279ed2f3cee780e517482b425c8886d
2023-09-07 19:19:53 +08:00
hiyouga
b5acec34f7 add deepspeed check in PPO training
Former-commit-id: e203ec7f71f504ccbaa89c27d20b8a0d9fa53f7e
2023-09-07 19:12:40 +08:00
hiyouga
86d835878c fix #809
Former-commit-id: 2783ca75365d7c373cefba039788a48f0b8f35fc
2023-09-07 19:04:32 +08:00
hiyouga
eae7b331d3 fix baichuan templates
Former-commit-id: f48a49e835b32f3991cfad8874c7b9c78953809f
2023-09-07 18:54:14 +08:00
hiyouga
ed89e29bcc update baichuan2 template
Former-commit-id: 16d9f8ba176443c5b397233da621600d6e1e1eec
2023-09-06 21:43:06 +08:00
hiyouga
c2b1886aff add Baichuan2 models
Former-commit-id: 90b3f02c44c0b8cc1b59f37af3a1ec28874a8a61
2023-09-06 18:40:11 +08:00
hiyouga
218f36bca5 add Baichuan2 models
Former-commit-id: 36960025e9274b574f57e7a7bf453cd96956e922
2023-09-06 18:36:04 +08:00
hoshi-hiyouga
b91fc1f5b3 Merge pull request #786 from kinghuin/patch-1
fix utils.py bug

Former-commit-id: 26aad616340748e1594a60119ca9434908bf7465
2023-09-05 10:49:34 +08:00
Q
2a22bf9c15 fix utils.py bug
Former-commit-id: dc490117d50c3cbc070b804bac89400f4290272f
2023-09-05 10:38:01 +08:00
hiyouga
62e2037125 fix #763
Former-commit-id: e424b928a35097b783af879a2290f59b2158801d
2023-09-01 23:13:05 +08:00
hiyouga
e5b72c6a77 refactor dataset_attr, add eos in pt, fix #757
Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
2023-09-01 19:00:45 +08:00
codingma
93be211f80 Merge pull request #741 from hiyouga/feature-addDatasetCheck
Feature add dataset check

Former-commit-id: 4b6dabe73d2c7edc94cd495390577c8bcf88428b
2023-08-31 20:57:36 +08:00
codemayq
9ae3fb4ced update llama2 template
Former-commit-id: 01de1d51d9fa5a22a338b6ed18ffad4d0ad5e3e8
2023-08-30 16:23:56 +08:00
codemayq
f641075789 add dataset stage check
Former-commit-id: 5c719a7ce988339d034a653456da9742dc2cec7c
2023-08-30 16:23:08 +08:00
codingma
f7658db1b6 Merge pull request #651 from hiyouga/feature-dataset_stage
add dataset stage

Former-commit-id: 3b0ef57405cbc22ff8ce4eef2cfcb73872519db5
2023-08-28 16:03:45 +08:00
codemayq
b869bc1a20 add ad gen dataset
Former-commit-id: fcd0788aa4dda0cecc1420d369d371032a207810
2023-08-27 20:35:32 +08:00
codemayq
a72d756d77 add text format dataset preview in webui
Former-commit-id: cd30871aadb40cd3d598a6d0b415946744d2d550
2023-08-24 19:45:36 +08:00
codemayq
d3fd8f89b8 add stage in DatasetAttr
Former-commit-id: 9c55200d8de0623640f529dbf39b8b0f169636d3
2023-08-23 20:54:53 +08:00
hiyouga
180a05a446 fix import error
Former-commit-id: b3207a974a45038591b8cbbcf20d1ca1142d6679
2023-08-23 20:45:03 +08:00
hiyouga
eb9ac9ee1f fix #649
Former-commit-id: e6120a937ddb4f3c0b9bcb2466742f5cf4f77f8c
2023-08-23 20:21:15 +08:00
codemayq
a6662b73f5 add readme for dataset
Former-commit-id: bdcb0ea40e726e4c5752f938b379ed9a18e7e1d0
2023-08-23 19:55:45 +08:00
codemayq
cbc7db3478 add dataset stage and filter dataset when stage chosen in webui
Former-commit-id: 26e4136449a4df6028d834fd16a0f4a7c532759d
2023-08-23 18:54:23 +08:00
hiyouga
4606340f0f fix webui
Former-commit-id: 95304b6822d9fe04bcddc1ee246a56389bd5f96a
2023-08-23 11:03:35 +08:00
hoshi-hiyouga
d4b4ccd597 Merge pull request #644 from hiyouga/fix-quantization_bit
fix quantization bit is ""

Former-commit-id: e1a8eca182e532b48e472919b4474656a726b40c
2023-08-23 10:45:45 +08:00
codemayq
9c3f4e3a37 fix quantization bit is ""
Former-commit-id: 0dcab66f8843e2887f9f7ca66334122fef35c5b7
2023-08-23 10:08:17 +08:00
codemayq
440e00d8f9 fix quantization is ""
Former-commit-id: 2469cc16d1dd3f5ee822edc18b2d7021ff7cba03
2023-08-23 10:04:03 +08:00
hiyouga
6310613699 update template
Former-commit-id: a95f3a4d62de1073a78125401cf4289ec0523156
2023-08-22 19:46:09 +08:00
hoshi-hiyouga
f55907dbea Merge pull request #629 from panpan0000/main
add rm dataset explanation

Former-commit-id: c2b4571d0ffb6298d6e07212982d9c13efd65adf
2023-08-22 13:41:44 +08:00
Peter Pan
5cac87d317 add rm dataset explanation
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>

Former-commit-id: 1efb95025be6501f1b30b20e7c711d3590b5d1ee
2023-08-22 01:33:59 -04:00
hoshi-hiyouga
9c0622de13 Merge pull request #619 from hiyouga/feature-templateTest
add template encode test

Former-commit-id: 8a1587ae49fff3968e0182f4fcc9a65dfdb260fc
2023-08-21 20:56:34 +08:00
codemayq
37b93c8b71 add template encode test
Former-commit-id: c15e0d6847cbc055d8376b3c43ac4fbd17b5877a
2023-08-21 20:51:24 +08:00
hiyouga
d6be98cda6 fix #617
Former-commit-id: a7bdaf1c92c7d798caf8438dc42a8972632ec584
2023-08-21 18:16:11 +08:00
hiyouga
4d128acc17 fix #608
Former-commit-id: c02a6809124fcfd06628c49c95d419ec2d8cc8ef
2023-08-21 17:49:36 +08:00
hiyouga
516df9ecce fix baichuan template for training #597 #616
Former-commit-id: 6530c1d972301eac9ef058b3235618bb09833f15
2023-08-21 17:41:51 +08:00
hiyouga
8eec1d50e1 fix #595
Former-commit-id: a360ccf9aa0484ce783eaa5857cf698b3ac2051e
2023-08-20 16:40:00 +08:00
hoshi-hiyouga
cfb096d43a Merge pull request #596 from beat4ocean/beat
fix KeyError: 'lang' bug

Former-commit-id: dd22541cdf1b832d20bb894d78c034afce841bfb
2023-08-20 16:37:40 +08:00
beat4ocean
713fa28804 fix KeyError: 'lang' bug
Former-commit-id: 4d4d9172b1f362cb4876315f1f5739e417055065
2023-08-20 15:32:36 +08:00
hiyouga
5549f35939 fix ppo trainer #551
Former-commit-id: 050a5447c191b8c50a0826a0f03bae499bff8b48
2023-08-20 14:07:11 +08:00
hiyouga
6eed1db36c Release v0.1.7
Former-commit-id: 81abe8d6cabaa1ebe74dc32a5dc143389e4c9f31
2023-08-18 17:21:27 +08:00
hiyouga
948124f55e tiny fix
Former-commit-id: 0ee159654ac6339c162745b004e2152ba6fe3c81
2023-08-18 13:07:35 +08:00
hiyouga
2b191ca776 support ppo score norm (trl 0.5.1.dev required)
Former-commit-id: 2b25db6d260ec1532281a592e873579346c7d21c
2023-08-18 12:02:42 +08:00
hiyouga
be4d2822ea fix PPO trainer #551 , update readme
Former-commit-id: faead74849470cebae9e37cde5fab2a71b32aa43
2023-08-18 11:43:10 +08:00
hiyouga
736ddd0319 update readme
Former-commit-id: beaf2fb737dbe64d35334d88b42935c89ef09eee
2023-08-18 01:51:55 +08:00
hiyouga
dfa289aa72 Update .gitignore
Former-commit-id: a1772a4dfef8dfaf7c2c321fad0a70ccf95fe6a0
2023-08-18 01:43:42 +08:00
hiyouga
c2644f939a update training resuming
Former-commit-id: 2ec75c31f609e65116ac3b621eeb7d8ccbf69135
2023-08-18 01:41:17 +08:00
hoshi-hiyouga
f11c1ae562 Merge pull request #434 from niuba/main
add last_checkpoint support

Former-commit-id: b78d461f2826c194c332ead37825704c2cb8b910
2023-08-18 01:38:31 +08:00
hoshi-hiyouga
3126164aa6 Merge branch 'main' into main
Former-commit-id: 870d2c7bf74d0da5a927bef4b8b01d15cc66a3e9
2023-08-18 01:37:23 +08:00
hiyouga
ed10486cad support bf16 ppo #551
Former-commit-id: 092088967de7409a2d51847cfc7afc83a8887320
2023-08-18 00:40:32 +08:00
hiyouga
04fa430c6c fix ChatGLM2 ppo #527 #528
Former-commit-id: 60d6ad64d7c9f6445b0df8de0153c3a311974198
2023-08-18 00:34:59 +08:00
hiyouga
fa1893b59c fix generation bug #532
Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
2023-08-17 22:21:34 +08:00
hiyouga
e993e717a5 fix streaming in pt stage #548 #549
Former-commit-id: 050e992bee2a9293cc7399b578de807b5bf9bddc
2023-08-17 17:59:26 +08:00
hiyouga
c80e56423a update readme
Former-commit-id: b74af3c9cf29e1690ae4d5acb27599b1abd152e2
2023-08-17 11:00:22 +08:00
hiyouga
ffa09a01d6 fix baichuan and intern template
Former-commit-id: e1fd18fa6ef1009f978aca5210a259251a0b19a6
2023-08-17 01:27:20 +08:00
hiyouga
7d04f8567b fix generation
Former-commit-id: 66a0300d312ef91c24fcf80667fa3b0bb8e1a342
2023-08-16 22:39:54 +08:00
hiyouga
baa709674f fix system prompt
Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
2023-08-16 01:35:52 +08:00
hiyouga
ca9a494d0c fix baichuan template #481
Former-commit-id: 7608c6c25877d97ef26a1c209c4073c9c42f4535
2023-08-15 11:38:21 +08:00
hoshi-hiyouga
37eb8c05cc Merge pull request #516 from liuyanyi/add_gitignore
[Enhance] Add .gitignore file

Former-commit-id: 12cfe5482f5ef95d8c386d0af0de381e72eab0f9
2023-08-15 11:25:40 +08:00
hiyouga
7c046edb7b fix ChatGLM RLHF
Former-commit-id: 4e43e887e432ceb7e9287b4e309b63af3c3ba1bf
2023-08-15 11:19:20 +08:00
Yanyi Liu
22cea38b20 Add .gitignore
Former-commit-id: a2ebdeef81706596617da4409fc5da71739bccdc
2023-08-15 11:13:45 +08:00
hiyouga
ef2ca0a827 alert pad_token source
Former-commit-id: f26a84e0d927d2554890daf431a93652e18f4235
2023-08-15 00:07:56 +08:00
hiyouga
7f0b908de2 update webui
Former-commit-id: da30d0fb4abdb825f3383ddd106bb06a84695b7a
2023-08-14 22:45:26 +08:00
hoshi-hiyouga
5fc5e776ff Merge pull request #511 from hiyouga/feature-autoTemplate
add template match and stage in webui

Former-commit-id: 413752ecba845cddaff5fb48db7d3d24b960eec1
2023-08-14 22:44:04 +08:00
codemayq
93b281c016 auto match template when change model_name
Former-commit-id: ab2d7ab0572765ce33a52ac71641062d5d904db4
2023-08-14 20:56:05 +08:00
codemayq
9585699918 add template match and stage in webui
Former-commit-id: d6283e7f041f08f76d18350cb5f6a6c58ca80e92
2023-08-14 20:42:59 +08:00
hiyouga
bceaba551d fix ChatGLM lm_head #494
Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
2023-08-14 14:14:48 +08:00
hiyouga
0bfeed3a7e fix bug in webui
Former-commit-id: c95f0f687689934379b6c24abf872ffcde06073b
2023-08-14 11:38:42 +08:00
hiyouga
70a780c3c0 fix webui cache
Former-commit-id: 9aba5c197fbc8abaab77f454374f8b497f0310d0
2023-08-14 11:37:01 +08:00
hiyouga
d74ab5306c update readme_zh
Former-commit-id: bdfe7e0285fdeb3a2728669dbdabf70c9652735c
2023-08-14 11:13:25 +08:00
hiyouga
688e8601ab web UI integrating RLHF
Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
2023-08-14 10:48:47 +08:00
hiyouga
4933ab5956 fix #480
Former-commit-id: ec15ca8fffacba2c34e1849c5ce90ca9989d66a2
2023-08-14 00:23:56 +08:00
hiyouga
6c7225a5d4 fix webui
Former-commit-id: 2c8b7414be9b43e20cc1d0575cc4dc1c7545fd86
2023-08-12 23:52:07 +08:00
hiyouga
a22982f2fa tiny fix
Former-commit-id: 50a34c043de6d9e1410291e1d8c1ea9d53754e9e
2023-08-12 22:02:43 +08:00
hiyouga
c95479dddb fix rope scaling
Former-commit-id: 2e0dd36700ec5e8294581c1db4b9431f755fc5f8
2023-08-12 22:00:01 +08:00
hiyouga
fc48bd8da0 update readme
Former-commit-id: 94ac570cb62aa9cd5dba105f0bb4c4da43eca042
2023-08-12 21:29:06 +08:00
hiyouga
d5323bfa3f update readme
Former-commit-id: ecfe87f34b383901f8e97ffb90af459cd55419b1
2023-08-12 21:25:19 +08:00
hiyouga
e9d4a2b507 update readme
Former-commit-id: eadbe9b7a0b6c8897e7a763b519cc5b7e00f3b2c
2023-08-12 21:23:05 +08:00
hiyouga
37bcbe8046 update readme
Former-commit-id: 6fa381400c21fa249cebcdff8c3afd72f8de20b3
2023-08-12 21:00:11 +08:00
hiyouga
fdfb644f0a support rope scaling, fix #475 #476 #478
Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
2023-08-12 20:46:27 +08:00
hoshi-hiyouga
cde9f3db57 Merge pull request #479 from hiyouga/feature-addCmdExport
add sft script preview in webui

Former-commit-id: 060225e57d13d8164beb6920410c181fbb28b77a
2023-08-12 20:41:52 +08:00
codemayq
8bf5a98815 add sft script preview in webui
Former-commit-id: 2b72649b404750226aa418b61ef5a6c9ac03938f
2023-08-12 13:53:55 +08:00
hiyouga
be566a15a5 fix unusual output of 8bit models #278 #391
Former-commit-id: 337ce5272b81f5561162beb08814b0e5abf23703
2023-08-12 00:25:29 +08:00
hiyouga
d5f1b99ac4 Release v0.1.6
Former-commit-id: 43c8b3c3c8bfb2e32d17fb3e8b194938e37d54bd
2023-08-11 23:25:57 +08:00
hiyouga
2144bb0e27 Update README_zh.md
Former-commit-id: 4fc154bcf039ba3f9240213158df757881cf3579
2023-08-11 14:06:02 +08:00
hiyouga
bc665bacc7 add defaults
Former-commit-id: 4636d3bbe6b984ca93e3a80ae5239f3ddda461bd
2023-08-11 13:56:26 +08:00
hiyouga
52bfcf4883 fix stop word in baichuan template
Former-commit-id: cba5ac9cfc81f11b97831998ea15def5e0b487c2
2023-08-11 13:51:46 +08:00
hiyouga
06df3d6fb6 fix baichuan template
Former-commit-id: b1681fe35346381cda613297f1cbb710f0a6daa6
2023-08-11 13:45:47 +08:00
hiyouga
ca719a8697 support DPO training (2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
2023-08-11 03:02:53 +08:00
hoshi-hiyouga
72dfd74005 Merge pull request #451 from jovialchen/main
huggingface login for projects must login while running

Former-commit-id: 246ac241277908909b81cdf85fec1f24449dbae9
2023-08-10 17:25:38 +08:00
hiyouga
69302c4420 fix webui val size
Former-commit-id: 490c067d4e0828832e0ebdb704a9207dc974b15b
2023-08-10 15:20:44 +08:00
jiongxuc
42d7019b2e huggingface login for projects must login while running
Former-commit-id: 0a4a2a1d3e0ff1f57215512d294d782080bd383c
2023-08-10 14:57:12 +08:00
hiyouga
5f0d0d6b9b fix template
Former-commit-id: e3967eb1cdd8d19e8afee9ba52e7eb7d6cd86129
2023-08-09 23:14:27 +08:00
hiyouga
76cb63e4f6 fix template
Former-commit-id: 907e8cd86fbd4cdfa26dad21ceaf6e01d8fe37e4
2023-08-09 23:10:20 +08:00
hiyouga
467d571206 support val set in streaming mode
Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
2023-08-09 23:00:26 +08:00
hiyouga
972bfa700a fix tokenizer
Former-commit-id: 7849587cd4e149291d08edef9a528a1bad796c7e
2023-08-09 17:52:15 +08:00
niuba
458955d0fb add last_checkpoint support
Former-commit-id: 9f1977e4de00b14a9d1b555c25bcaf12998d5046
2023-08-09 16:39:27 +08:00
hiyouga
990eeccf45 fix sft trainer
Former-commit-id: 08cc888b1569572d0cd20bcf3f07e20072a0311a
2023-08-09 16:35:03 +08:00
hiyouga
a3a7465f00 fix rm #420, fix template #426, fix #423
Former-commit-id: 70ea3caaa7a7695c77179cd1bb18707a80a373d7
2023-08-09 16:23:31 +08:00
hoshi-hiyouga
031a819257 fix llama2 template
Former-commit-id: 6c74f726d4e672f5a1a57df201c27c1f697384f0
2023-08-09 00:58:27 +08:00
hoshi-hiyouga
eb4b4e3c8c fix tokenizer
Former-commit-id: fa463ef279b596d5d53cc169831f51b42031fc05
2023-08-09 00:54:54 +08:00
hiyouga
d2e1fe9b1d update webui
Former-commit-id: 343a4cd82b07a40f96ba413d1d991419ff07a24a
2023-08-09 00:26:11 +08:00
hiyouga
6e27a9e39a fix tokenizer #417
Former-commit-id: 01aa678311bfd213a4b410a4e0ff09f48a0d40a1
2023-08-08 23:59:41 +08:00
hiyouga
805478c911 fix bug
Former-commit-id: 0dff1d951f1a9fe05a74d334bf477b55c7c64199
2023-08-08 21:28:28 +08:00
hiyouga
a281cdeb89 fix bug
Former-commit-id: c13ce66021b21e015871b84489eeafa127a424a4
2023-08-08 17:55:55 +08:00
hiyouga
cda698a67f fix chatml template #408
Former-commit-id: 21e0cc3f44c35ae689b00b274391492f413725ac
2023-08-08 17:44:39 +08:00
hiyouga
15acd17716 update args spec
Former-commit-id: a006068346edda6e2851b23d2005fdb218a7287d
2023-08-07 15:23:35 +08:00
hiyouga
34a2bddfcd update readme
Former-commit-id: 06bcbb901f69265632892a5fcbc956b8be1153da
2023-08-07 15:02:02 +08:00
hiyouga
370f817549 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Efficient-Tuning
Former-commit-id: 5c5657227db285048e3850631badb040eea9b6ca
2023-08-07 13:59:16 +08:00
hiyouga
041390c37e fix #376
Former-commit-id: a5b01257ba3323bcb2dd0217fb89a387e39ddbec
2023-08-07 13:58:59 +08:00
hoshi-hiyouga
d9fe4bf500 Merge pull request #382 from hiyouga/feature-updateReadme
add detailed model configs

Former-commit-id: 371c50cf3fd4e3f5e8fb390508c27cb5f18fa531
2023-08-07 13:43:38 +08:00
hiyouga
e0c7e944fc update trainer
Former-commit-id: 0d39b53a5164e34d22fe0a492eaa0d7ac63102fe
2023-08-07 13:34:35 +08:00
codemayq
0845fe67db add detailed model configs
Former-commit-id: 438c43f820e39738eaa1c296aadcf6d141c3289f
2023-08-07 09:30:23 +08:00
hiyouga
fe3b12d900 fix qwen eos token
Former-commit-id: 770830c67886f5872b39b9608949ec62d4616b27
2023-08-06 13:31:17 +08:00
hiyouga
a70d56864e fix qwen tokenizer #361
Former-commit-id: 78a2fa95c8ab669254a6c8fce8138c4395fb0a09
2023-08-05 17:06:05 +08:00
hiyouga
fdbb2c5378 fix template for tiktoken
Former-commit-id: 8328447f81eb5b90310df08cf2928c83ef6355fe
2023-08-05 13:42:42 +08:00
hiyouga
3c0aaf42af remove redundant code
Former-commit-id: dcec1717592107ba9d26eb2ac520309da19d1805
2023-08-05 00:27:27 +08:00
hiyouga
438e19160a fix template
Former-commit-id: b88200a88ea112e043dc44058606805c60e32844
2023-08-05 00:25:00 +08:00
hiyouga
f2b2ff6950 fix llama2 template
Former-commit-id: 08f37145e0bca5f1a8fd7bad01c64dc69b07361b
2023-08-05 00:07:54 +08:00
hoshi-hiyouga
86cef96305 Support safe ChatML template, fix qwen tok #351 #354
https://github.com/openai/openai-python/blob/main/chatml.md
Former-commit-id: 94bfc9d85f7cef3a5eb15085e0124a424373814f
2023-08-05 00:00:23 +08:00
hiyouga
5f50944baf fix bos and eos token
Former-commit-id: ab386f4c0fb5eaac24264a5bbef4c03deeb92158
2023-08-04 23:55:57 +08:00
hiyouga
0804fd2353 fix encode
Former-commit-id: ec382abd906d93cf78c7fbaec753ce6bcf8cfebd
2023-08-04 23:27:55 +08:00
hiyouga
86419eb457 support chatml safe encoding
Former-commit-id: ea52bb135bf9d07738091006ec7ada8df14cf15e
2023-08-04 23:14:28 +08:00
hiyouga
76f3ae7bf3 support interleave probs
Former-commit-id: 168d99816f9bdc746c587f7f09753ba7e0a4b19d
2023-08-04 21:27:35 +08:00
hiyouga
aaa85190eb fix webui export model
Former-commit-id: c34469c05e681239db23e2e666b5ac6a4e38aba9
2023-08-04 14:20:27 +08:00
hiyouga
e2a4e926b9 fix mtloader
Former-commit-id: ca48c2c02c3cfa9afb99971b50daeda9cf14e7cb
2023-08-03 19:29:02 +08:00
hiyouga
d6e922dc1c tiny fix
Former-commit-id: 81ef7017a4c96441951adeff0276cc5ab76a3544
2023-08-03 17:42:28 +08:00
hiyouga
27f4317ec6 fix qwen inference
Former-commit-id: 823f0de0ca0a92b6f48a90e5ffe57a48dc018f1d
2023-08-03 16:31:55 +08:00
hiyouga
e434348216 fix qwen inference
Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
2023-08-03 16:15:38 +08:00
hiyouga
2e19afedb8 support Qwen-7B, fix InternLM-7B inference
Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
2023-08-03 15:53:32 +08:00
hiyouga
da08fa7c63 update web demo
Former-commit-id: 5b6ad9adb665096bfb36dc90789a1d4a16345122
2023-08-03 13:28:28 +08:00
hiyouga
9c96b97dc7 fix webui
Former-commit-id: e87630ef77977b2879f1199b9a421acbbbb32a51
2023-08-03 12:43:12 +08:00
hiyouga
28a51b622b modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
2023-08-02 23:17:36 +08:00
hiyouga
8bd1da7144 fix PPO trainer
Former-commit-id: 21982a7d4dd9b7c3a1145b481f02b9990e32dc00
2023-08-02 19:10:23 +08:00
hiyouga
e4d0b8ee6e update ppo trainer
Former-commit-id: c27136a83e167465d3f825e40f10c7b9fcfbf97a
2023-08-02 18:46:41 +08:00
hiyouga
1dfb28b362 fix memory leak of PPO trainer
Former-commit-id: 38410894a5ebf0b043b55a6bd5cca3cd0a44b27d
2023-08-02 17:41:34 +08:00
hiyouga
ba618947e7 release v0.1.5
Former-commit-id: d619e76bc4098c29a7fdc05f5a71208bd1079c9f
2023-08-02 16:10:31 +08:00
hoshi-hiyouga
f81041b502 Merge pull request #307 from GitYCC/feature/fix-llama2-prompt-template
[feature] Fix template of Llama2 to match the offical template

Former-commit-id: a750b1f1ed16e20233df4d2f1c20507122919f5a
2023-08-02 15:51:28 +08:00
YC Chen
f2533a2800 [fix] Remove useless code
Former-commit-id: 077e1556112913e4eeef47e581055183b39d5404
2023-08-02 14:35:35 +08:00
YC Chen
bb5b4a7f26 [feature] Fix template of Llama2 to match the offical template
Former-commit-id: 1a98d45aefd95eea3768fb93e5a9da257ec61181
2023-08-02 14:10:15 +08:00
hiyouga
20bff87021 fix bug in preprocessing
Former-commit-id: 94952894576dfc4b42118162aec9aa35c3503c40
2023-08-02 01:10:28 +08:00
hiyouga
722b954800 update readme
Former-commit-id: 5154a04869be8c47e591351565b7842339fb99e4
2023-08-01 18:48:27 +08:00
hiyouga
19256086c7 fix #296
Former-commit-id: 69e9ed9b96a7cfb3d3b43ec5ddd01aa0bfd9b784
2023-08-01 18:43:53 +08:00
hiyouga
250fecfcd4 Fix #294
Former-commit-id: 09762d9849655f5e6c71b9472d55b42489dd944b
2023-08-01 18:13:03 +08:00
hiyouga
cb4d1d5ebb restore from git lfs
Former-commit-id: 0c734a37113b773ae7c0bc8b8d1af39b15bc0fb2
2023-08-01 16:33:25 +08:00
hiyouga
d7d557fb2e Update .gitattributes
Former-commit-id: 92e68f9f30c2fc91ae1b40865bc5c2d94899ba22
2023-08-01 16:28:54 +08:00
hiyouga
0b8e19b6a6 fix webui
Former-commit-id: cf4cd52d36894f53a6ec45d003f887771012e5b4
2023-08-01 12:11:37 +08:00
hiyouga
8e26eb374e fix RM save model
Former-commit-id: 8104cc2425431eb1cddccf3909855296116f922b
2023-08-01 11:56:17 +08:00
hiyouga
9bba01a033 use git lfs
Former-commit-id: 4886d0071751f68c5a2d926bd9fcee0c93337322
2023-08-01 10:14:08 +08:00
hiyouga
661890b8a1 release v0.1.4
Former-commit-id: 81f84aaf2e120e39edb28ef42893939fc9a184e2
2023-08-01 10:08:47 +08:00
hiyouga
772ad4ec6b fix inference
Former-commit-id: 55dc2bdd3eaa552c655e584fc3cbbf017c7bc3e7
2023-08-01 00:06:48 +08:00
hiyouga
6f65f8cb3b fix arg check
Former-commit-id: 2c5c73de9ebc88e2d04e80754781c94a571133a0
2023-07-31 23:48:57 +08:00
hiyouga
43e83548b9 update readme
Former-commit-id: d99cda254e5025ff3f968d256197ab031bfabef1
2023-07-31 23:42:32 +08:00
hiyouga
dd3f3e9749 support streaming data, fix #284 #274 #268
Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
2023-07-31 23:33:00 +08:00
hiyouga
124f61b404 Update data_args.py
Former-commit-id: 41ac5455af195747ba369c3a6dc7d412a366d54d
2023-07-28 17:42:41 +08:00
hiyouga
e8748cc6f3 update readme
Former-commit-id: 14d20cd1fdcfd1f2842362f70472b666e5d48c7d
2023-07-28 17:36:00 +08:00
hiyouga
fafec8b7a5 fix #268
Former-commit-id: 1eee0207fb370bb9e234e9bd3f9a0c47d7d01bc9
2023-07-28 17:02:26 +08:00
hiyouga
030daca686 update dataset
Former-commit-id: 4a044aabbd19c92a9ae93c1c30536f5086fd47f9
2023-07-26 17:05:12 +08:00
hiyouga
ac587438f8 fix #242
Former-commit-id: 80a346e29beb49e8935b786e2af1059fdc4954b2
2023-07-25 17:04:02 +08:00
hiyouga
c145bbef3c update dataset
Former-commit-id: 4fc2c3293d91d8464527ebd1ddabe572c8355616
2023-07-23 20:01:43 +08:00
hiyouga
745c46ee04 Update README_zh.md
Former-commit-id: 9d3c8803a34c06a2a5512fec3f841d7efcab3e3c
2023-07-22 14:31:16 +08:00
hiyouga
a707f5b502 update readme, fix web ui postprocess
Former-commit-id: ba51ab3379100108f7b52a3c2444ccdd99e8a6ef
2023-07-22 14:29:22 +08:00
hoshi-hiyouga
dc2e801077 Merge pull request #221 from mrhan1993/main
根据GLM Efficient Tuning添加中文README,web添加了server_port参数

Former-commit-id: 948f2abcb818211ee99d4c140e26044ca591369f
2023-07-22 13:04:25 +08:00
NULL
b56d5108b2 Merge branch 'hiyouga:main' into main
Former-commit-id: 8244c5b554ad0823e8ebea3d5583a6ecf9a66d2d
2023-07-21 17:00:26 +08:00
mrhan1993
8e6b7034fe 根据GLM Efficient Tuning添加中文README,web添加了server_port
Former-commit-id: 29e3acd23eafd891667d7a860ec544a5b05d3c33
2023-07-21 16:57:58 +08:00
hiyouga
dad7ca6633 release v0.1.3
Former-commit-id: 62c68bcbf591516e8f90b47810bea6f710fd23f6
2023-07-21 16:48:34 +08:00
hiyouga
a1468139a5 fix save function
Former-commit-id: 1d6beb0c8490a7531ffdf7a2819410597b200d12
2023-07-21 14:09:07 +08:00
hiyouga
49c90044ce Update runner.py
Former-commit-id: d7309deae46cfcdeeee79f54736df9b7e93b79ce
2023-07-21 13:35:19 +08:00
hiyouga
0f7cdac207 update web UI, support rm predict #210
Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
2023-07-21 13:27:27 +08:00
hiyouga
c4e9694c6e release v0.1.2
Former-commit-id: 04aad91b71cc3a1acaf1bcec4304ce6b2098f7dc
2023-07-20 22:33:59 +08:00
hiyouga
2006a96570 fix api
Former-commit-id: 4c3e8be325045e432b31c519132123c7b0689262
2023-07-20 22:14:54 +08:00
hoshi-hiyouga
5dcd95645f Merge pull request #213 from Ehco1996/patch-1
feat: support pass args before init web app
Former-commit-id: b0612c05bc10c281c0a95e08c5517c3fb0a72029
2023-07-20 22:12:07 +08:00
hiyouga
9b3304b054 update UI, fix #212
Former-commit-id: ac92c2bd7c47353759474fad9412f21b38c65501
2023-07-20 22:09:06 +08:00
Ehco
e580d4ef41 feat: support pass args before init web app
as title

Former-commit-id: 434a5077288927e0be15cd066ca3e562111fad4d
2023-07-20 21:49:26 +08:00
hiyouga
64db4abc68 Update README.md
Former-commit-id: 6dc67a495ec7d9fdc2574bae92063ed8a9099725
2023-07-20 17:23:16 +08:00
hiyouga
5ba0b80e5c simplify code
Former-commit-id: d3731754ab7c28ae81f60784e0e4213f279d93fe
2023-07-20 15:08:57 +08:00
hiyouga
7a43ff3d89 tiny fix
Former-commit-id: 22b1be7bbb9e7bd863acb88bf7365090b1b8235d
2023-07-19 22:53:46 +08:00
hiyouga
7e1a1d141a fix #199
Former-commit-id: 7fc778b49bc17688aca39fffe01f9d33e03e0c28
2023-07-19 22:51:29 +08:00
hiyouga
6d881f161b add datasets
Former-commit-id: 02e4b47dea1b25905c61f2ace88bab112610f021
2023-07-19 20:59:15 +08:00
hiyouga
a02b3e6192 fix #196
Former-commit-id: 85fd82926db345a590a7fb32c0e352a1d2f025c3
2023-07-19 17:35:38 +08:00
hiyouga
bcdee9fc19 fix #194
Former-commit-id: 9792921531efefb4bcddbde4380169a78fe064a6
2023-07-19 17:07:33 +08:00
hiyouga
8b688251be support LLaMA-2
Former-commit-id: 04dfda054855ee9256586aacbd382f8fb0bfed04
2023-07-19 16:42:14 +08:00
hiyouga
718f3382ad add LLaMA2 template
Former-commit-id: 246421bd35cf7bb2203ac4fc924e6cd1c292954d
2023-07-19 00:44:49 +08:00
hiyouga
dc8283d3d7 fix API
Former-commit-id: 9b10c9a12e33ab897056ecc61d977d221c19141b
2023-07-19 00:01:14 +08:00
130 changed files with 8066 additions and 3849 deletions

160
.gitignore vendored Normal file
View File

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

128
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
`hoshihiyouga AT gmail DOT com`.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

423
README.md
View File

@@ -1,92 +1,156 @@
# LLaMA Efficient Tuning
# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Efficient-Tuning?style=social)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Efficient-Tuning)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
👋 Join our [WeChat](assets/wechat.jpg).
\[ English | [中文](README_zh.md) \]
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**.
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
## Changelog
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
[23/07/31] We supported **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model. If you want to train with RTX3090, use `git checkout baichuan-7b-rtx3090` to switch to the `baichuan-7b-rtx3090` branch and try the `--baichuan_rtx_gpu true` argument. (Other RTX series GPUs can also be tried)
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
## Supported Models
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
- [InternLM](https://github.com/InternLM/InternLM) (7B)
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
> [!NOTE]
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
## Supported Training Approaches
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- Full-parameter tuning
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
- Full-parameter tuning
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use `--quantization_bit 4/8` argument to enable QLoRA.
## Provided Datasets
- For pre-training:
- [Wiki Demo](data/wiki_demo.txt)
- For supervised fine-tuning:
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat](https://github.com/thunlp/UltraChat)
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [WebNovel (Chinese)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward model training:
- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
<details><summary>Pre-training datasets</summary>
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
</details>
<details><summary>Supervised fine-tuning datasets</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
</details>
<details><summary>Preference datasets</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
</details>
Please refer to [data/README.md](data/README.md) for details.
Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
```bash
pip install --upgrade huggingface_hub
@@ -97,51 +161,53 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- jieba, rouge-chinese and nltk (used at evaluation)
- gradio and matplotlib (used in web_demo.py)
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
- sentencepiece, protobuf and tiktoken
- jieba, rouge-chinese and nltk (used at evaluation and predict)
- gradio and matplotlib (used in web UI)
- uvicorn, fastapi and sse-starlette (used in API)
And **powerful GPUs**!
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
## Getting Started
### Data Preparation (optional)
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
Note: please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
### Dependence Installation (optional)
```bash
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
cd LLaMA-Efficient-Tuning
git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
pip install -r requirements.txt
```
### All-in-one Web UI
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
```bash
python src/train_web.py
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### (Continually) Pre-Training
### Train on a single GPU
> [!IMPORTANT]
> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
#### Pre-Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset wiki_demo \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
@@ -155,15 +221,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### Supervised Fine-Tuning
#### Supervised Fine-Tuning
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
@@ -177,36 +245,43 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### Reward Model Training
#### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--learning_rate 1e-6 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### PPO Training (RLHF)
#### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
@@ -217,30 +292,51 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--plot_loss
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Distributed Training
#### Use Huggingface Accelerate
```bash
accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
<details><summary>Example config for LoRA training</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
@@ -256,38 +352,76 @@ use_cpu: false
</details>
### Evaluation (BLEU and ROUGE_CHINESE)
#### Use DeepSpeed
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_en \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 50 \
--predict_with_generate
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--deepspeed ds_config.json \
... # arguments (same as above)
```
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### Export model
```bash
python src/export_model.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--export_dir path_to_export
```
### API Demo
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
See `http://localhost:8000/docs` for API documentation.
> [!NOTE]
> Visit `http://localhost:8000/docs` for API documentation.
### CLI Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@@ -295,48 +429,77 @@ python src/cli_demo.py \
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### Export model
### Evaluation
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_llama_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
--template vanilla \
--task mmlu \
--split test \
--lang en \
--n_shot 5 \
--batch_size 4
```
### Predict
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_predict \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
> [!NOTE]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
## Projects using LLaMA Factory
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights:
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
## Citation
If this work is helpful, please kindly cite as:
```bibtex
@Misc{llama-efficient-tuning,
title = {LLaMA Efficient Tuning},
@Misc{llama-factory,
title = {LLaMA Factory},
author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
year = {2023}
}
```
## Acknowledgement
This repo is a sibling of [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning). They share a similar code structure of efficient tuning on large language models.
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Efficient-Tuning&type=Date)
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)

504
README_zh.md Normal file
View File

@@ -0,0 +1,504 @@
# LLaMA Factory: 轻松的大模型训练与评估
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
👋 加入我们的[微信群](assets/wechat.jpg)。
\[ [English](README.md) | 中文 \]
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。该模式目前仅支持单卡训练
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
## 更新日志
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune例如 `--neft_alpha 5`
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn` 参数以启用 FlashAttention-2。
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming``--max_steps 10000` 参数来流式加载数据集。
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft)。
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
## 模型
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
> [!NOTE]
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
## 训练方法
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
## 数据集
<details><summary>预训练数据集</summary>
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
</details>
<details><summary>指令微调数据集</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
</details>
<details><summary>偏好数据集</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
</details>
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
```bash
pip install --upgrade huggingface_hub
huggingface-cli login
```
## 软件依赖
- Python 3.8+ 和 PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
- sentencepiece, protobuf 和 tiktoken
- jieba, rouge-chinese 和 nltk (用于评估及预测)
- gradio 和 matplotlib (用于网页端交互)
- uvicorn, fastapi 和 sse-starlette (用于 API)
以及 **强而有力的 GPU**
## 如何使用
### 数据准备(可跳过)
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
> [!NOTE]
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
### 环境搭建(可跳过)
```bash
git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
pip install -r requirements.txt
```
如果要在 Windows 平台上开启量化 LoRAQLoRA需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### 单 GPU 训练
> [!IMPORTANT]
> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
#### 预训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset wiki_demo \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### 指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### 奖励模型训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-6 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
```
#### DPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_llama_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### 多 GPU 分布式训练
#### 使用 Huggingface Accelerate
```bash
accelerate config # 首先配置分布式环境
accelerate launch src/train_bash.py # 参数同上
```
<details><summary>LoRA 训练的 Accelerate 配置示例</summary>
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
#### 使用 DeepSpeed
```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--deepspeed ds_config.json \
... # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例</summary>
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### 导出微调后的完整模型
```bash
python src/export_model.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--export_dir path_to_export
```
### API 服务
```bash
python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
> [!NOTE]
> 关于 API 文档请见 `http://localhost:8000/docs`。
### 命令行测试
```bash
python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### 浏览器测试
```bash
python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### 模型评估
```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_llama_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--template vanilla \
--task ceval \
--split validation \
--lang zh \
--n_shot 5 \
--batch_size 4
```
### 模型预测
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
> [!NOTE]
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
## 使用了 LLaMA Factory 的项目
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
## 引用
如果您觉得此项目有帮助,请考虑以下列格式引用
```bibtex
@Misc{llama-factory,
title = {LLaMA Factory},
author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
year = {2023}
}
```
## 致谢
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)

View File

@@ -1,53 +1,107 @@
Data format in `dataset_info.json`:
If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
```json
"dataset_name": {
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
"subset": "the name of the subset. (optional, default: None)",
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
"content": "the key in the message represents the content. (default: value, for sharegpt)"
}
}
```
`dataset_info.json` 中的数据集定义格式:
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
```json
"数据集名称": {
"hf_hub_url": "HuggingFace上的项目地址若指定则忽略下列三个参数",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选",
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"output": "model response (required)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
```json
"dataset_name": {
"columns": {
"prompt": "数据集代表提示词的表头名称(默认:instruction",
"query": "数据集代表请求的表头名称(默认:input",
"response": "数据集代表回答的表头名称(默认:output",
"history": "数据集代表历史对话的表头名称默认None"
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
```
部分预置数据集简介:
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
| 数据集名称 | 规模 | 描述 |
| --- | --- | --- |
| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | 斯坦福大学开源的 Alpaca 数据集,训练了 Alpaca 这类早期基于 LLaMA 的模型 |
| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | 使用 ChatGPT 翻译的 Alpaca 数据集 |
| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | 基于 GPT-4 的 self-instruction 数据集 |
| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | 包含约 200 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | 包含约 100 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | 包含约 50 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | 包含约 40 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的个性化角色对话数据,包含角色介绍 |
| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | 包含约 25 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文数学题数据,包含解题过程 |
| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | 包含约 80 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的用户与助手的多轮对话 |
| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | 包含日文、简繁体中文、英文等多类数据,数据集原用于 Guanaco 模型训练 |
| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | 中文对话大模型 firefly流萤的中文数据集包含多个 NLP 任务 |
| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | 英文代码生成任务数据集 |
| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | 用于微调的指令数据集集合 |
| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | 百度知道汇集的中文问答数据集 |
| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | 清华 NLP 发布的大规模多轮对话数据集 |
The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
BELLE 数据集是由 ChatGPT 产生的数据集,不保证数据准确性,所有类 GPT 模型产生的 self-instruction 数据集均不能保证其准确性。
For the pre-training datasets, only the `prompt` column will be used for training.
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
```json
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
}
```
The dataset in sharegpt format should follow the below format:
```json
[
{
"conversations": [
{
"from": "human",
"value": "user instruction"
},
{
"from": "gpt",
"value": "model response"
}
]
}
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
```json
"dataset_name": {
"columns": {
"messages": "conversations",
"role": "from",
"content": "value"
}
}
```
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.

107
data/README_zh.md Normal file
View File

@@ -0,0 +1,107 @@
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
```json
"数据集名称": {
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选留空不影响训练",
"subset": "数据集子集的名称可选默认None",
"ranking": "是否为偏好数据集可选默认False",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns": {
"prompt": "数据集代表提示词的表头名称默认instruction用于 alpaca 格式)",
"query": "数据集代表请求的表头名称默认input用于 alpaca 格式)",
"response": "数据集代表回答的表头名称默认output用于 alpaca 格式)",
"history": "数据集代表历史对话的表头名称默认None用于 alpaca 格式)",
"messages": "数据集代表消息列表的表头名称默认conversations用于 sharegpt 格式)",
"role": "消息中代表发送者身份的键名默认from用于 sharegpt 格式)",
"content": "消息中代表文本内容的键名默认value用于 sharegpt 格式)"
}
}
```
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"output": "模型回答(必填)",
"history": [
["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"]
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
```json
"数据集名称": {
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
```
其中 `prompt``response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
```json
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"优质回答",
"劣质回答"
]
}
```
而 sharegpt 格式的数据集按照以下方式组织:
```json
[
{
"conversations": [
{
"from": "human",
"value": "用户指令"
},
{
"from": "gpt",
"value": "模型回答"
}
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
```json
"数据集名称": {
"columns": {
"messages": "conversations",
"role": "from",
"content": "value"
}
}
```
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
预训练数据集和偏好数据集尚不支持 sharegpt 格式。

View File

@@ -1,6 +1,5 @@
import json
import datasets
from typing import Any, Dict, List
_DESCRIPTION = "BELLE multiturn chat dataset."
@@ -23,11 +22,9 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
def _info(self):
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
@@ -37,7 +34,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
@@ -48,10 +45,11 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
)
]
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat with history
def _generate_examples(self, filepath: str):
with open(filepath, "r", encoding="utf-8") as f:
for key, row in enumerate(f):
data = json.loads(row)
conversations = []
prompt = data["instruction"].strip()
response = data["output"].strip()
@@ -59,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip()
prompt = prompt[:human_idx].strip()
history = []
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
while prompt.rfind("Assistant:") != -1:
assist_idx = prompt.rfind("Assistant:")
@@ -67,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip()
history.insert(0, (old_query, old_resp))
conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else:
break
prompt = prompt[:human_idx].strip()
yield key, {
"instruction": query,
"output": response,
"history": history
}
yield key, {"conversations": conversations}

View File

@@ -3,7 +3,7 @@ import datasets
from typing import Any, Dict, List
_DESCRIPTION = "An example of dataset for LLaMA."
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""

View File

@@ -1,9 +1,9 @@
import json
import datasets
from typing import Any, Dict, List
from typing import List
_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
@@ -42,7 +42,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download_and_extract(_URLS)
return [
datasets.SplitGenerator(
@@ -59,7 +59,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
)
]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
def _generate_examples(self, filepaths: List[str]):
key = 0
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:

View File

@@ -1 +1 @@
0a57fbc1d8cb08a8cd71c5eb8425cf59206ffed6
57fd080be5bffe4153fe3ee26a175e3d56da30f3

View File

@@ -1,2 +0,0 @@
{"id": 0,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利David Clayton Henrie美国演员。近来在迪士尼频道原创电视影集《少年魔法师》Wizards of Waverly Place当中演出贾斯汀·鲁索Justin Russo一角。\n\n大卫·亨利出生在加州Mission Viejo在凤凰城长大。他的胞弟劳伦斯·亨利Lorenzo Henrie也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔Lucy Hale之后与其交往于2009年分手。\n\n10岁时大卫·亨利和SAG在凤凰城签订了合约并开始走出去试镜。 9岁的时候在沙加缅度进行商业拍摄SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天他和他的家人搬到了好莱坞。他预定他的前2支商业试镜扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁大卫有了他的第一次重大突破在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker和琳达布莱儿、乔治甘迺迪共同演出并要求回来Hallmark movie公司。 \n\n在18岁时大卫得到了迪士尼频道原创系列演出机会该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长隔年为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}
{"id": 1,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利David Clayton Henrie美国演员。近来在迪士尼频道原创电视影集《少年魔法师》Wizards of Waverly Place当中演出贾斯汀·鲁索Justin Russo一角。\n\n大卫·亨利出生在加州Mission Viejo在凤凰城长大。他的胞弟劳伦斯·亨利Lorenzo Henrie也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔Lucy Hale之后与其交往于2009年分手。\n\n10岁时大卫·亨利和SAG在凤凰城签订了合约并开始走出去试镜。 9岁的时候在沙加缅度进行商业拍摄SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天他和他的家人搬到了好莱坞。他预定他的前2支商业试镜扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁大卫有了他的第一次重大突破在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker和琳达布莱儿、乔治甘迺迪共同演出并要求回来Hallmark movie公司。 \n\n在18岁时大卫得到了迪士尼频道原创系列演出机会该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长隔年为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}

View File

@@ -0,0 +1 @@
38c89869c6aeca2a3af9ea1e09afe460f9b46810

View File

@@ -1,6 +1,6 @@
import json
import datasets
from typing import Any, Dict, List
from typing import List
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
@@ -21,15 +21,13 @@ _LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
class BelleMultiturn(datasets.GeneratorBasedBuilder):
class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
def _info(self):
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
@@ -39,8 +37,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
@@ -50,7 +48,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
)
]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
for row in f:
@@ -58,19 +56,14 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
data = json.loads(row)
except:
continue
key = data["id"]
content = data["data"]
key: int = data["id"]
content: List[str] = data["data"]
if len(content) % 2 == 1:
content.pop(-1)
if len(content) < 2:
continue
query = content[-2]
response = content[-1]
history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
yield key, {
"instruction": query,
"output": response,
"history": history
}
conversations = [{
"from": "human" if i % 2 == 0 else "gpt",
"value": content[i]
} for i in range(len(content))]
yield key, {"conversations": conversations}

166
evaluation/ceval/ceval.py Normal file
View File

@@ -0,0 +1,166 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
journal={arXiv preprint arXiv:2305.08322},
year={2023}
}
"""
_DESCRIPTION = """\
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
"""
_HOMEPAGE = "https://cevalbenchmark.com"
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
_URL = "ceval.zip"
task_list = [
"computer_network",
"operating_system",
"computer_architecture",
"college_programming",
"college_physics",
"college_chemistry",
"advanced_mathematics",
"probability_and_statistics",
"discrete_mathematics",
"electrical_engineer",
"metrology_engineer",
"high_school_mathematics",
"high_school_physics",
"high_school_chemistry",
"high_school_biology",
"middle_school_mathematics",
"middle_school_biology",
"middle_school_physics",
"middle_school_chemistry",
"veterinary_medicine",
"college_economics",
"business_administration",
"marxism",
"mao_zedong_thought",
"education_science",
"teacher_qualification",
"high_school_politics",
"high_school_geography",
"middle_school_politics",
"middle_school_geography",
"modern_chinese_history",
"ideological_and_moral_cultivation",
"logic",
"law",
"chinese_language_and_literature",
"art_studies",
"professional_tour_guide",
"legal_professional",
"high_school_chinese",
"high_school_history",
"middle_school_history",
"civil_servant",
"sports_science",
"plant_protection",
"basic_medicine",
"clinical_medicine",
"urban_and_rural_planner",
"accountant",
"fire_engineer",
"environmental_impact_assessment_engineer",
"tax_accountant",
"physician",
]
class CevalConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
class Ceval(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CevalConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"id": datasets.Value("int32"),
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
"explanation": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(
data_dir, "test", f"{task_name}_test.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(
data_dir, "val", f"{task_name}_val.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(
data_dir, "dev", f"{task_name}_dev.csv"
),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath, encoding="utf-8")
for i, instance in enumerate(df.to_dict(orient="records")):
if "answer" not in instance.keys():
instance["answer"] = ""
if "explanation" not in instance.keys():
instance["explanation"] = ""
yield i, instance

167
evaluation/cmmlu/cmmlu.py Normal file
View File

@@ -0,0 +1,167 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
journal={arXiv preprint arXiv:2306.09212},
year={2023}
}
"""
_DESCRIPTION = """\
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
"""
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
_URL = "cmmlu.zip"
task_list = [
'agronomy',
'anatomy',
'ancient_chinese',
'arts',
'astronomy',
'business_ethics',
'chinese_civil_service_exam',
'chinese_driving_rule',
'chinese_food_culture',
'chinese_foreign_policy',
'chinese_history',
'chinese_literature',
'chinese_teacher_qualification',
'clinical_knowledge',
'college_actuarial_science',
'college_education',
'college_engineering_hydrology',
'college_law',
'college_mathematics',
'college_medical_statistics',
'college_medicine',
'computer_science',
'computer_security',
'conceptual_physics',
'construction_project_management',
'economics',
'education',
'electrical_engineering',
'elementary_chinese',
'elementary_commonsense',
'elementary_information_and_technology',
'elementary_mathematics',
'ethnology',
'food_science',
'genetics',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_geography',
'high_school_mathematics',
'high_school_physics',
'high_school_politics',
'human_sexuality',
'international_law',
'journalism',
'jurisprudence',
'legal_and_moral_basis',
'logical',
'machine_learning',
'management',
'marketing',
'marxist_theory',
'modern_chinese',
'nutrition',
'philosophy',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_study',
'sociology',
'sports_science',
'traditional_chinese_medicine',
'virology',
'world_history',
'world_religions',
]
class CMMLUConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.1"), **kwargs)
class CMMLU(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CMMLUConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
for i, instance in enumerate(df.to_dict(orient="records")):
question = instance.pop("Question", "")
answer = instance.pop("Answer", "")
instance["question"] = question
instance["answer"] = answer
yield i, instance

167
evaluation/mmlu/mmlu.py Normal file
View File

@@ -0,0 +1,167 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021}
}
"""
_DESCRIPTION = """\
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
"""
_HOMEPAGE = "https://github.com/hendrycks/test"
_LICENSE = "MIT"
_URL = "mmlu.zip"
task_list = [
"high_school_european_history",
"business_ethics",
"clinical_knowledge",
"medical_genetics",
"high_school_us_history",
"high_school_physics",
"high_school_world_history",
"virology",
"high_school_microeconomics",
"econometrics",
"college_computer_science",
"high_school_biology",
"abstract_algebra",
"professional_accounting",
"philosophy",
"professional_medicine",
"nutrition",
"global_facts",
"machine_learning",
"security_studies",
"public_relations",
"professional_psychology",
"prehistory",
"anatomy",
"human_sexuality",
"college_medicine",
"high_school_government_and_politics",
"college_chemistry",
"logical_fallacies",
"high_school_geography",
"elementary_mathematics",
"human_aging",
"college_mathematics",
"high_school_psychology",
"formal_logic",
"high_school_statistics",
"international_law",
"high_school_mathematics",
"high_school_computer_science",
"conceptual_physics",
"miscellaneous",
"high_school_chemistry",
"marketing",
"professional_law",
"management",
"college_physics",
"jurisprudence",
"world_religions",
"sociology",
"us_foreign_policy",
"high_school_macroeconomics",
"computer_security",
"moral_scenarios",
"moral_disputes",
"electrical_engineering",
"astronomy",
"college_biology",
]
class MMLUConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
class MMLU(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
MMLUConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "test", f"{task_name}_test.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "val", f"{task_name}_val.csv"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "dev", f"{task_name}_dev.csv"
),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath)
df.columns = ["question", "A", "B", "C", "D", "answer"]
for i, instance in enumerate(df.to_dict(orient="records")):
yield i, instance

View File

@@ -1,14 +1,17 @@
torch>=1.13.1
transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.19.0
peft>=0.3.0
trl>=0.4.7
transformers>=4.31.0,<4.35.0
datasets>=2.14.0
accelerate>=0.21.0
peft>=0.6.0
trl>=0.7.4
gradio>=3.38.0,<4.0.0
scipy
sentencepiece
protobuf
tiktoken
jieba
rouge-chinese
nltk
gradio>=3.36.0
uvicorn
pydantic
fastapi

View File

@@ -25,12 +25,12 @@ def main():
version=get_version(),
author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use fine-tuning framework using PEFT",
description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.8.0",

View File

@@ -1,13 +1,14 @@
# coding=utf-8
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
import uvicorn
from llmtuner import create_app
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
print("Visit http://localhost:8000/docs for API document.")
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
if __name__ == "__main__":
app = create_app()
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
main()

View File

@@ -1,12 +1,16 @@
# coding=utf-8
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner import ChatModel, get_infer_args
try:
import platform
if platform.system() != "Windows":
import readline
except ImportError:
print("Install `readline` for a better experience.")
def main():
chat_model = ChatModel(*get_infer_args())
chat_model = ChatModel()
history = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
@@ -24,6 +28,7 @@ def main():
if query.strip() == "clear":
history = []
torch_gc()
print("History has been removed.")
continue

10
src/evaluate.py Normal file
View File

@@ -0,0 +1,10 @@
from llmtuner import Evaluator
def main():
evaluator = Evaluator()
evaluator.eval()
if __name__ == "__main__":
main()

View File

@@ -1,16 +1,8 @@
# coding=utf-8
# Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from llmtuner import get_train_args, load_model_and_tokenizer
from llmtuner import export_model
def main():
model_args, _, training_args, finetuning_args, _ = get_train_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
tokenizer.save_pretrained(training_args.output_dir)
print("model and tokenizer have been saved at:", training_args.output_dir)
export_model()
if __name__ == "__main__":

View File

@@ -1,7 +1,10 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
from llmtuner.webui import create_ui
from llmtuner.eval import Evaluator
from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.1.0"
__version__ = "0.3.0"

View File

@@ -1,14 +1,8 @@
import json
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
from llmtuner.tuner import get_infer_args
from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
@@ -23,17 +17,40 @@ from llmtuner.api.protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
def create_app():
chat_model = ChatModel(*get_infer_args())
def to_json(data: BaseModel) -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -49,57 +66,75 @@ def create_app():
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
system = prev_messages.pop(0).content
else:
prefix = None
system = None
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
if request.stream:
generate = predict(query, history, prefix, request)
generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
responses = chat_model.chat(
query, history, system,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
choices.append(ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length+response_length
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role=Role.ASSISTANT, content=response),
finish_reason=Finish.STOP
)
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
yield to_json(chunk)
for new_text in chat_model.stream_chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
query, history, system,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens
):
if len(new_text) == 0:
continue
@@ -110,7 +145,7 @@ def create_app():
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
yield to_json(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@@ -118,12 +153,13 @@ def create_app():
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
yield to_json(chunk)
yield "[DONE]"
return app
if __name__ == "__main__":
app = create_app()
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

View File

@@ -20,9 +20,6 @@ class ModelCard(BaseModel):
object: Optional[str] = "model"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
owned_by: Optional[str] = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = []
class ModelList(BaseModel):
@@ -43,6 +40,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
do_sample: Optional[bool] = True
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = 1

View File

@@ -1 +1 @@
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat.chat_model import ChatModel

View File

@@ -0,0 +1,132 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.tokenizer.padding_side = "left"
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.system_prompt = data_args.system_prompt
def _process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
system = system or self.system_prompt
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
)
prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.to_dict()
generating_args.update(dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id
))
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
if max_new_tokens:
generating_args.pop("max_length", None)
generating_args["max_new_tokens"] = max_new_tokens
gen_kwargs = dict(
inputs=input_ids,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor()
)
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> List[Response]:
r"""
Args: query, history, system, **input_kwargs
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length"
))
return results
@torch.inference_mode()
def stream_chat(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()
yield from streamer

View File

@@ -1,85 +0,0 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import Template
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner import load_model_and_tokenizer
class ChatModel:
def __init__(
self,
model_args: ModelArguments,
data_args: DataArguments,
finetuning_args: FinetuningArguments,
generating_args: GeneratingArguments
) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.template = Template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
self.generating_args = generating_args
def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
))
if max_length:
gen_kwargs.pop("max_new_tokens", None)
gen_kwargs["max_length"] = max_length
if max_new_tokens:
gen_kwargs.pop("max_length", None)
gen_kwargs["max_new_tokens"] = max_new_tokens
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response_length = len(outputs)
return response, (prompt_length, response_length)
@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()
for new_text in streamer:
yield new_text

View File

@@ -0,0 +1,4 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset

145
src/llmtuner/data/loader.py Normal file
View File

@@ -0,0 +1,145 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_files = None
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_files = None
elif dataset_attr.load_from == "file":
data_path, data_name = None, None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
dataset = load_dataset(
path=data_path,
name=data_name,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming
)
if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# convert dataset from sharegpt format to alpaca format
outputs = {"prompt": [], "query": [], "response": [], "history": []}
for msg_list in examples[dataset_attr.messages]:
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
if len(msg_list) == 0:
continue
msg_pairs = []
user_role, assistant_role = None, None
for idx in range(0, len(msg_list), 2):
if user_role is None and assistant_role is None:
user_role = msg_list[idx][dataset_attr.role]
assistant_role = msg_list[idx + 1][dataset_attr.role]
else:
if (
msg_list[idx][dataset_attr.role] != user_role
or msg_list[idx+1][dataset_attr.role] != assistant_role
):
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
if len(msg_pairs) != 0:
outputs["prompt"].append(msg_pairs[-1][0])
outputs["query"].append("")
outputs["response"].append(msg_pairs[-1][1])
outputs["history"].append(msg_pairs[:-1])
return outputs
if dataset_attr.formatting == "sharegpt": # convert format
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
dataset = dataset.map(
convert_format,
batched=True,
remove_columns=column_names,
**kwargs
)
else:
for column_name in ["prompt", "query", "response", "history"]: # align dataset
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.system_prompt: # add system prompt
system_prompt = dataset_attr.system_prompt
if data_args.streaming:
dataset = dataset.map(lambda _: {"system": system_prompt})
else:
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=data_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
)
else:
raise ValueError("Unknown mixing strategy.")

View File

@@ -0,0 +1,275 @@
import os
import tiktoken
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
from datasets import load_from_disk
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def preprocess_dataset(
dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
# make sure the saved tokenizer is the same as the original one
if hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
source_len, target_len = len(source_ids), len(target_ids)
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
source_ids = source_ids[:max_source_len]
if target_len > max_target_len:
target_ids = target_ids[:max_target_len]
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
labels = labels[:data_args.cutoff_len]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and query != ""):
continue
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
if len(labels) > data_args.cutoff_len:
labels = labels[:data_args.cutoff_len]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
continue
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
return load_from_disk(data_args.cache_path)
with training_args.main_process_first(desc="dataset map pre-processing"):
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
return dataset

View File

@@ -0,0 +1,682 @@
import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
logger = get_logger(__name__)
@dataclass
class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
system: str
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
efficient_eos: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
return encoded_pairs
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[str, List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
return system, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
bos_ids = [tokenizer.bos_token_id]
else: # baichuan, qwen and gpt2 models have no bos token
bos_ids = []
if tokenizer.eos_token_id is None:
raise ValueError("EOS token is required.")
if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
eos_ids = []
else:
eos_ids = [tokenizer.eos_token_id]
return bos_ids, eos_ids
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
if len(prefix_ids) != 0: # has prefix
prefix_ids = bos_ids + prefix_ids + sep_ids
else:
prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
def _convert_inputs_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
system: Optional[str] = None,
query: Optional[str] = None,
idx: Optional[str] = None
) -> List[int]:
r"""
Converts context to token ids.
"""
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
token_ids = []
for elem in context:
if isinstance(elem, str):
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
if len(elem) != 0:
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
return token_ids
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has no sep_ids
query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
templates: Dict[str, Template] = {}
def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
system: str,
sep: List[Union[str, Dict[str, str]]],
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
system=system,
sep=sep,
stop_words=stop_words,
use_history=use_history,
efficient_eos=efficient_eos
)
def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None:
return None
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False
)
return template
register_template(
name="alpaca",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
sep=[
"\n\n"
]
)
register_template(
name="aquila",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant:"
],
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
sep=[
"###"
],
stop_words=[
"</s>"
],
efficient_eos=True
)
register_template(
name="baichuan",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
efficient_eos=True
)
register_template(
name="baichuan2",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[],
efficient_eos=True
)
register_template(
name="belle",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
"\n\n"
]
)
register_template(
name="bluelm",
prefix=[
"{{system}}"
],
prompt=[
{"token": "[|Human|]:"},
"{{query}}",
{"token": "[|AI|]:"}
],
system="",
sep=[]
)
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"\n\n"
],
efficient_eos=True
)
register_template(
name="chatglm3",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
{"token": "<|user|>"},
"\n",
"{{query}}",
{"token": "<|assistant|>"}
],
system="",
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
],
efficient_eos=True
)
register_template(
name="deepseek",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer."
),
sep=[
"\n",
{"token": "<|EOT|>"},
"\n\n"
],
stop_words=[
"<|EOT|>"
],
efficient_eos=True
)
register_template(
name="default",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant:"
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[
"\n"
]
)
register_template(
name="falcon",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nFalcon:"
],
system="",
sep=[
"\n"
],
efficient_eos=True
)
register_template(
name="intern",
prefix=[
"{{system}}"
],
prompt=[
"<|User|>:{{query}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
],
system="",
sep=[
{"token": "<eoa>"},
"\n"
],
stop_words=[
"<eoa>"
],
efficient_eos=True
)
register_template(
name="llama2",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
)
register_template(
name="mistral",
prefix=[
"{{system}}"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="",
sep=[]
)
register_template(
name="openchat",
prefix=[
"{{system}}"
],
prompt=[
"GPT4 Correct User: {{query}}",
{"token": "<|end_of_turn|>"},
"GPT4 Correct Assistant:"
],
system="",
sep=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
"<|end_of_turn|>"
],
efficient_eos=True
)
register_template(
name="qwen",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}"
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
system="You are a helpful assistant.",
sep=[
{"token": "<|im_end|>"},
"\n"
],
stop_words=[
"<|im_end|>"
],
efficient_eos=True
)
register_template(
name="starchat",
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
],
system="",
sep=[
{"token": "<|end|>"},
"\n"
],
stop_words=[
"<|end|>"
],
efficient_eos=True
)
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix=[],
prompt=[
"{{query}}"
],
system="",
sep=[],
use_history=False
)
register_template(
name="vicuna",
prefix=[
"{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT:"
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
register_template(
name="xverse",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nAssistant: "
],
system="",
sep=[]
)
register_template(
name="yayi",
prefix=[
{"token": "<|System|>"},
":\n{{system}}"
],
prompt=[
{"token": "<|Human|>"},
":\n{{query}}\n\n",
{"token": "<|YaYi|>"},
":"
],
system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[
"\n\n"
],
stop_words=[
"<|End|>"
]
)
register_template(
name="zephyr",
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
{"token": "</s>"}
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "</s>"},
{"token": "<|assistant|>"}
],
system="You are a friendly chatbot who always responds in the style of a pirate",
sep=[]
)
register_template(
name="ziya",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
"\n"
]
)

View File

@@ -0,0 +1,61 @@
import hashlib
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -1,2 +0,0 @@
from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset

View File

@@ -1,63 +0,0 @@
import os
import json
import time
from datetime import timedelta
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.start_time = time.time()
self.tracker = {}
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if "loss" not in state.log_history[-1]:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")

View File

@@ -1,106 +0,0 @@
import os
import hashlib
from typing import List
from datasets import Dataset, concatenate_datasets, load_dataset
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
model_args: ModelArguments,
data_args: DataArguments
) -> Dataset:
def checksum(file_path, hash):
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
ext2type = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
data_path = dataset_attr.dataset_name
data_files = None
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_files = None
elif dataset_attr.load_from == "file":
data_path = None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = ext2type.get(data_files[0].split(".")[-1], None)
else:
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = ext2type.get(data_files[0].split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
checksum(data_files[0], dataset_attr.dataset_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
else:
raise NotImplementedError
raw_datasets = load_dataset(
data_path,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
)
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
prefix_data = [dataset_attr.source_prefix] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
]: # every dataset will have 4 columns same as each other
if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
dataset = dataset.add_column("prefix", prefix_data)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
else:
all_datasets = concatenate_datasets(all_datasets)
return all_datasets

View File

@@ -1,172 +0,0 @@
from typing import Literal
from itertools import chain
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from datasets import Dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import Template
from llmtuner.hparams import DataArguments
def preprocess_dataset(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:
column_names = list(dataset.column_names)
prompt_template = Template(data_args.prompt_template)
# support question with a single answer or multiple answers
def get_dialog(examples):
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
yield dialog
def preprocess_pretrain_dataset(examples):
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
concatenated_ids = list(chain(*text_ids))
total_length = len(concatenated_ids)
block_size = data_args.max_source_length - 1
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of max_source_length
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
for i in range(0, total_length, block_size)]
return {
"input_ids": result,
"labels": result.copy()
}
def preprocess_supervised_dataset(examples):
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
model_inputs = {"input_ids": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
for dialog in get_dialog(examples):
input_ids, labels = [], []
for i in range(len(dialog) // 2):
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
break
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_unsupervised_dataset(examples):
# build inputs with format `<bos> X` and labels with format `<bos> Y`
model_inputs = {"input_ids": [], "labels": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
target_ids = target_ids[:data_args.max_target_length]
model_inputs["input_ids"].append(source_ids)
model_inputs["labels"].append(target_ids)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
return model_inputs
def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
skip_special_tokens=False)
))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
preprocess_function = preprocess_pretrain_dataset
elif stage == "sft":
preprocess_function = preprocess_unsupervised_dataset \
if training_args.predict_with_generate else preprocess_supervised_dataset
elif stage == "rm":
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
preprocess_function = preprocess_unsupervised_dataset
with training_args.main_process_first(desc="dataset map pre-processing"):
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
if stage == "pt":
print_unsupervised_dataset_example(dataset[0])
elif stage == "sft":
print_supervised_dataset_example(dataset[0])
elif stage == "rm":
print_pairwise_dataset_example(dataset[0])
elif stage == "ppo":
print_unsupervised_dataset_example(dataset[0])
return dataset

View File

@@ -0,0 +1 @@
from llmtuner.eval.evaluator import Evaluator

View File

@@ -0,0 +1,116 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import json
import torch
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = self._encode_choices()
def _encode_choices(self) -> List[int]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None:
mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token,
revision=self.model_args.model_revision
)
with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
download_mode="force_redownload"
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
query, resp, history = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=self.template.use_history
)
input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp=resp, history=history
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(resp)
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad(
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
).to(self.model.device)
preds = self.batch_inference(batch_input)
outputs += preds
corrects = (np.array(outputs) == np.array(labels))
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join([
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items() if len(category_correct)
])
print(score_info)
if self.eval_args.save_dir is not None:
os.makedirs(self.eval_args.save_dir, exist_ok=False)
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
json.dump(results, f, indent=2)
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from llmtuner.extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(
self,
example: Dict[str, str]
) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
eval_templates: Dict[str, EvalTemplate] = {}
def register_eval_template(
name: str,
system: str,
choice: str,
answer: str,
prefix: str
) -> None:
eval_templates[name] = EvalTemplate(
system=system,
choice=choice,
answer=answer,
prefix=prefix
)
def get_eval_template(name: str) -> EvalTemplate:
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template
register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" "
)
register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n"
)

View File

@@ -1,74 +1,162 @@
import os
import json
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
class SavePeftModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False):
model.pretrained_model.save_pretrained(output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(args.output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
if getattr(model, "is_peft_model", False):
model.pretrained_model.save_pretrained(args.output_dir)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
self.start_time = time.time()
self.tracker = {}
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def timing(self):
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
Event called at the end of training.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
r"""
Event called after a successful prediction.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if not state.is_world_process_zero:
if not state.is_local_process_zero:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"eval_loss": state.log_history[-1].get("eval_loss", None),
"predict_loss": state.log_history[-1].get("predict_loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
)
if self.runner is not None:
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
))
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")
f.write(json.dumps(logs) + "\n")
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()

View File

@@ -1,40 +1,278 @@
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
LAYERNORM_NAMES = {"norm", "ln"}
FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
SUPPORTED_MODELS = {
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"Pre-Training": "pt"
}
def register_model_group(
models: Dict[str, str],
module: Optional[str] = None,
template: Optional[str] = None
) -> None:
prefix = None
for name, path in models.items():
if prefix is None:
prefix = name.split("-")[0]
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path
if module is not None:
DEFAULT_MODULE[prefix] = module
if template is not None:
DEFAULT_TEMPLATE[prefix] = template
register_model_group(
models={
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
},
module="W_pack",
template="baichuan"
)
register_model_group(
models={
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
},
module="W_pack",
template="baichuan2"
)
register_model_group(
models={
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1"
},
module="query_key_value"
)
register_model_group(
models={
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
},
module="query_key_value"
)
register_model_group(
models={
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
},
template="bluelm"
)
register_model_group(
models={
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
},
module="query_key_value",
template="chatglm2"
)
register_model_group(
models={
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
},
module="query_key_value",
template="chatglm3"
)
register_model_group(
models={
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
},
template="llama2_zh"
)
register_model_group(
models={
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-180B": "tiiuae/falcon-180B",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
},
module="query_key_value",
template="falcon"
)
register_model_group(
models={
"InternLM-7B": "internlm/internlm-7b",
"InternLM-20B": "internlm/internlm-20b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
},
template="intern"
)
register_model_group(
models={
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
},
module="qkv_proj"
)
register_model_group(
models={
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
}
"LLaMA-65B": "huggyllama/llama-65b"
}
)
DEFAULT_MODULE = { # will be deprecated
"LLaMA": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"InternLM": "q_proj,v_proj"
}
register_model_group(
models={
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
},
template="llama2"
)
register_model_group(
models={
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
},
template="mistral"
)
register_model_group(
models={
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
},
template="openchat"
)
register_model_group(
models={
"Phi1.5-1.3B": "microsoft/phi-1_5"
},
module="Wqkv"
)
register_model_group(
models={
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
},
module="c_attn",
template="qwen"
)
register_model_group(
models={
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
}
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
},
template="vicuna"
)
register_model_group(
models={
"XVERSE-7B": "xverse/XVERSE-7B",
"XVERSE-13B": "xverse/XVERSE-13B",
"XVERSE-65B": "xverse/XVERSE-65B",
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
},
template="xverse"
)
register_model_group(
models={
"Yayi-7B": "wenge-research/yayi-7b-llama2",
"Yayi-13B": "wenge-research/yayi-13b-llama2"
},
template="yayi"
)
register_model_group(
models={
"Yi-6B": "01-ai/Yi-6B",
"Yi-34B": "01-ai/Yi-34B"
}
)
register_model_group(
models={
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
"Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
},
template="zephyr"
)

View File

@@ -3,11 +3,17 @@ import logging
class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
"""
def __init__(self):
super().__init__()
self.log = ""
def reset(self):
self.log = ""
def emit(self, record):
if record.name == "httpx":
return
@@ -17,7 +23,9 @@ class LoggerHandler(logging.Handler):
def get_logger(name: str) -> logging.Logger:
r"""
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"
@@ -30,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler)
return logger
def reset_logging() -> None:
r"""
Removes basic config of root logger. (unused in script)
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))

View File

@@ -1,11 +1,25 @@
import gc
import os
import sys
import torch
from typing import List, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
try:
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers import HfArgumentParser
class AverageMeter:
@@ -28,78 +42,75 @@ class AverageMeter:
self.avg = self.sum / self.count
# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
def print_trainable_params(model: torch.nn.Module) -> None:
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_params = num_params * 2
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
return trainable_params, all_param
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: PreTrainedModel,
finetuning_type: str,
output_embedding_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> PreTrainedModel:
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
return model
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def torch_gc() -> None:
r"""
Collects GPU memory.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@@ -0,0 +1,55 @@
import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_rouge_available = is_package_available("rouge-chinese")
_starlette_available = is_package_available("sse-starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
def is_flash_attn2_available():
return _flash_attn2_available
def is_jieba_available():
return _jieba_available
def is_matplotlib_available():
return _matplotlib_available
def is_nltk_available():
return _nltk_available
def is_rouge_available():
return _rouge_available
def is_starlette_available():
return _starlette_available
def is_uvicorn_available():
return _uvicorn_available

View File

View File

@@ -0,0 +1,224 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
logger = logging.get_logger(__name__)
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class LlamaShiftShortAttention(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if getattr(self, "num_key_value_groups"):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# cast to half precision
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(self.config.torch_dtype)
key_states = key_states.to(self.config.torch_dtype)
value_states = value_states.to(self.config.torch_dtype)
if getattr(self, "num_key_value_groups", None):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat((
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
), dim=2)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
if attention_mask is not None:
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
# -q_len: assumes left padding when q_len != kv_len
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
attn_output_unpad = flash_attn_varlen_func(
unpadded_q,
unpadded_k,
unpadded_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat((
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
))
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Disable the transformation of the attention mask in LlamaModel as flash attention
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: torch.Tensor,
inputs_embeds: torch.Tensor,
past_key_values_length: int
) -> torch.Tensor:
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
return attention_mask

View File

@@ -1,11 +1,14 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)

View File

@@ -1,49 +0,0 @@
import os
import torch
from typing import Dict
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if os.path.exists(weights_file):
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
else:
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
return False
return True
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
if not os.path.exists(valuehead_file):
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
return True

View File

@@ -1,221 +0,0 @@
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class Format:
prefix: str
prompt: str
sep: str
use_history: bool
templates: Dict[str, Format] = {}
@dataclass
class Template:
name: str
def __post_init__(self):
if self.name in templates:
self.prefix = templates[self.name].prefix
self.prompt = templates[self.name].prompt
self.sep = templates[self.name].sep
self.use_history = templates[self.name].use_history
else:
raise ValueError("Template {} does not exist.".format(self.name))
def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> str:
r"""
Returns a string containing prompt without response.
"""
return "".join(self._format_example(query, history, prefix))
def get_dialog(
self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
r"""
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
"""
return self._format_example(query, history, prefix) + [resp]
def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
prefix = prefix if prefix else self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, "<dummy>")]
convs = []
for turn_idx, (user_query, bot_resp) in enumerate(history):
if turn_idx == 0:
convs.append(prefix + self.prompt.format(query=user_query))
convs.append(bot_resp)
else:
convs.append(self.sep + self.prompt.format(query=user_query))
convs.append(bot_resp)
return convs[:-1] # drop last
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Format(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history
)
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
use_history=False
)
r"""
Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt=" <reserved_102> {query} <reserved_103> ",
sep="</s>",
use_history=True
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
)

View File

@@ -1,5 +1,5 @@
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .general_args import GeneralArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments

View File

@@ -1,6 +1,6 @@
import os
import json
from typing import List, Optional
from typing import List, Literal, Optional
from dataclasses import dataclass, field
@@ -10,35 +10,72 @@ class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
source_prefix: Optional[str] = None
system_prompt: Optional[str] = None
subset: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
role: Optional[str] = "from"
content: Optional[str] = "value"
def __repr__(self) -> str:
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
@dataclass
class DataArguments:
"""
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: Optional[str] = field(
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default="alpaca_zh",
default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "The name of the folder containing datasets."}
metadata={"help": "Path to the folder containing the datasets."}
)
split: Optional[str] = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
cutoff_len: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable dataset streaming."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
)
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -47,14 +84,6 @@ class DataArguments:
default=None,
metadata={"help": "The number of processes to use for the preprocessing."}
)
max_source_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total input sequence length after tokenization."}
)
max_target_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total output sequence length after tokenization."}
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
@@ -67,30 +96,53 @@ class DataArguments:
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
source_prefix: Optional[str] = field(
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
)
dev_ratio: Optional[float] = field(
val_size: Optional[float] = field(
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
prompt_template: Optional[str] = field(
default="default",
metadata={"help": "Which template to use for constructing prompts in training and inference."}
sft_packing: Optional[bool] = field(
default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
)
cache_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the preprocessed datasets."}
)
def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
def __post_init__(self):
if self.reserved_label_len >= self.cutoff_len:
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
if self.streaming and self.cache_path:
raise ValueError("`cache_path` is incompatible with `streaming`.")
def init_for_training(self, seed: int): # support mixing multiple datasets
self.seed = seed
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
try:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
except Exception:
if self.dataset is not None:
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
dataset_info = None
if self.source_prefix is not None:
prefix_list = self.source_prefix.split("|")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
else:
prefix_list = [None] * len(dataset_names)
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = []
for i, name in enumerate(dataset_names):
@@ -108,12 +160,17 @@ class DataArguments:
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)

View File

@@ -0,0 +1,55 @@
import os
from typing import Literal, Optional
from dataclasses import dataclass, field
from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field(
metadata={"help": "Name of the evaluation task."}
)
task_dir: Optional[str] = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."}
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."}
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."}
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."}
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."}
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."}
)
download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."}
)
def __post_init__(self):
task_available = []
for folder in os.listdir(self.task_dir):
if os.path.isdir(os.path.join(self.task_dir, folder)):
task_available.append(folder)
if self.task not in task_available:
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")

View File

@@ -4,75 +4,181 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_hidden_layers: Optional[int] = field(
default=32,
metadata={"help": "Number of decoder blocks in the model. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
Falcon choices: [\"32\", \"60\"], \
Baichuan choices: [\"32\"]"}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
name_module_trainable: Optional[str] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
Qwen choices: [\"mlp\", \"attn\"], \
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
Others choices: the same as LLaMA."}
)
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
)
lora_dropout: Optional[float] = field(
default=0.1,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_target: Optional[str] = field(
default="q_proj,v_proj",
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
Others choices: the same as LLaMA."}
)
additional_target: Optional[str] = field(
default=None,
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
"""
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
)
ppo_target: Optional[float] = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
)
ppo_whiten_rewards: Optional[bool] = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
)
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
)
ref_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
reward_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."}
)
reward_model_type: Optional[Literal["lora", "full"]] = field(
default="lora",
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)
neft_alpha: Optional[float] = field(
default=0,
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
self.name_module_trainable = split_arg(self.name_module_trainable)
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
self.lora_target = split_arg(self.lora_target)
self.additional_target = split_arg(self.additional_target)
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Lora reward model only supports lora training.")
def save_to_json(self, json_path: str):
"""Saves the content of this instance in JSON format inside `json_path`."""
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Creates an instance from the content of `json_path`."""
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))

View File

@@ -1,13 +0,0 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)

View File

@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
@dataclass
class GeneratingArguments:
"""
r"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(
@@ -28,7 +28,7 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
max_length: Optional[int] = field(
default=None,
default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
)
max_new_tokens: Optional[int] = field(
@@ -46,6 +46,8 @@ class GeneratingArguments:
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", None):
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
return args

View File

@@ -1,11 +1,10 @@
import torch
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
class ModelArguments:
"""
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: str = field(
@@ -16,21 +15,17 @@ class ModelArguments:
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
)
use_fast_tokenizer: Optional[bool] = field(
default=False,
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
)
use_auth_token: Optional[bool] = field(
split_special_tokens: Optional[bool] = field(
default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
)
model_revision: Optional[str] = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
)
padding_side: Optional[Literal["left", "right"]] = field(
default="left",
metadata={"help": "The side on which the model should have padding applied."}
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model."}
@@ -43,30 +38,38 @@ class ModelArguments:
default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."}
)
compute_dtype: Optional[torch.dtype] = field(
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
metadata={"help": "Adopt scaled rotary positional embeddings."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
flash_attn: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
metadata={"help": "Enable FlashAttention-2 for faster training."}
)
shift_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
def __post_init__(self):
self.compute_dtype = None
self.model_max_length = None
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

View File

@@ -0,0 +1,5 @@
# Level: loader > adapter > parser, utils
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params

View File

@@ -0,0 +1,103 @@
import torch
from typing import TYPE_CHECKING
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
from llmtuner.model.utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.info("Checkpoint is not found at evaluation, load the original model.")
return model
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
for idx in trainable_layer_ids:
trainable_layers.append("{:d}.{}".format(idx, module_name))
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in trainable_layers):
param.requires_grad_(False)
else:
param.data = param.data.to(torch.float32)
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
checkpoint_to_resume = None
if model_args.checkpoint_dir is not None:
if is_trainable and finetuning_args.resume_lora_training:
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
for checkpoint in checkpoints_to_merge:
model = PeftModel.from_pretrained(model, checkpoint)
model = model.merge_and_unload()
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if checkpoint_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
else:
target_modules = finetuning_args.lora_target
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=target_modules,
modules_to_save=finetuning_args.additional_target
)
model = get_peft_model(model, lora_config)
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model

View File

@@ -0,0 +1,226 @@
import os
import math
import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
try:
from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
Support both training and inference.
"""
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token
}
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
**config_kwargs
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
# Set RoPE scaling
if model_args.rope_scaling is not None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
else:
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
model_args.rope_scaling, scaling_factor
))
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
# Set shift short attention (S^2-Attn)
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) == "llama":
setattr(config, "group_size_ratio", 0.25)
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
# Disable custom generate method (for Qwen and Baichuan2)
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
return model, tokenizer

View File

@@ -0,0 +1,210 @@
import os
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
logger = get_logger(__name__)
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments
data_args.init_for_training(training_args.seed)
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"]:
if training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info(
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
)
# postprocess model_args
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

165
src/llmtuner/model/utils.py Normal file
View File

@@ -0,0 +1,165 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
else:
linear_cls = torch.nn.Linear
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, linear_cls)
and not any([output_layer in name for output_layer in output_layer_names])
):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def generate_model_card(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
}
def load_valuehead_params(
path_or_repo_id: str,
model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
if finetuning_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)
logger.info("Upcasting weights in layernorm in float32.")
if finetuning_args.neft_alpha > 1e-6:
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear):
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
return args[0].to(output_layer.weight.dtype)
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
output_layer.register_forward_hook(fp32_forward_post_hook)
return model

View File

@@ -0,0 +1 @@
from llmtuner.train.tuner import export_model, run_exp

View File

@@ -0,0 +1 @@
from llmtuner.train.dpo.workflow import run_dpo

View File

@@ -0,0 +1,51 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
start, end = feature.size(0) - answer_len, feature.size(0)
else:
start, end = prompt_len, prompt_len + answer_len
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
label_positions = []
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append({
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len)
})
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(
concatenated_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch

View File

@@ -0,0 +1,75 @@
import torch
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedModel
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
**kwargs
):
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.is_encoder_decoder = model.config.is_encoder_decoder
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = beta
self.loss_type = loss_type
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
getattr(ref_model, "is_loaded_in_8bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,
model: Optional[torch.nn.Module] = None,
batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
all_logps = self._get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits

View File

@@ -0,0 +1,85 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import DataArguments, FinetuningArguments
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Create reference model
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
model=model,
ref_model=ref_model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1 @@
from llmtuner.train.ppo.workflow import run_ppo

View File

@@ -0,0 +1,338 @@
import os
import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
if reward_model is not None:
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
if self.args.max_steps > 0:
num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = sys.maxsize
max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
steps_in_epoch = len_dataloader
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try:
batch = next(dataiter)
except StopIteration:
dataiter = iter(self.dataloader)
batch = next(dataiter)
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
self.model.eval()
# Get inputs
self.tokenizer.padding_side = "right" # change padding side
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
rewards.extend(mini_batch_rewards)
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
self.model.train()
# Run PPO step
stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.config.log_with is not None:
try:
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
except:
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / steps_in_epoch, 2)
)
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
loss_meter.reset()
reward_meter.reset()
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
))
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.log_callback.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
@torch.no_grad()
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
if self.finetuning_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
logits_processor=get_logits_processor(),
**batch
)
if self.finetuning_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
else:
response_length = response_index[-1].item() + 1
queries.append(query[i, query_length:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
return queries, responses
@torch.no_grad()
def get_rewards(
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
unwrapped_model: "AutoModelForCausalLMWithValueHead"
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.
"""
if self.reward_model is None:
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
reward_model = self.reward_model if self.reward_model is not None else self.model
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.reward_model is None:
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None
):
r"""
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
if values.size(0) != input_ids.size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)
else:
del logits
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1] if return_logits else None,
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
Subclass and override to inject custom behavior.
"""
if self.args.should_save:
self._save(output_dir)

View File

@@ -0,0 +1,35 @@
import torch
from typing import TYPE_CHECKING, Dict, Literal, Optional
if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict({
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
})
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
layer_norm_params = {}
for name, param in model.named_parameters():
if param.data.dtype == torch.float32:
layer_norm_params[name] = param.data.detach().clone()
param.data = param.data.to(model.config.torch_dtype)
return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
for name, param in model.named_parameters():
if name in layernorm_params:
param.data = layernorm_params[name]

View File

@@ -0,0 +1,101 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from trl import PPOConfig
from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps
)
# Initialize our Trainer
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
# Training
if training_args.do_train:
ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -0,0 +1 @@
from llmtuner.train.pt.workflow import run_pt

View File

@@ -0,0 +1,65 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import generate_model_card, load_model_and_tokenizer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1 @@
from llmtuner.train.rm.workflow import run_rm

View File

@@ -1,8 +1,10 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
@@ -15,5 +17,11 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
features = [
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
}
for key in ("chosen_ids", "rejected_ids") for feature in features
]
return super().__call__(features)

View File

@@ -0,0 +1,101 @@
import os
import json
import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
logger = get_logger(__name__)
class PairwiseTrainer(Trainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
Note that the first element will be removed from the output tuple.
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
# Split the inputs and rewards into two parts, chosen and rejected
batch_size = inputs["input_ids"].size(0) // 2
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
chosen_scores, rejected_scores = [], []
# Compute pairwise loss. Only backprop on the different tokens before padding
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
loss = 0
for i in range(batch_size):
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
if len(check_divergence) == 0:
end_index = chosen_length
div_index = end_index - 1
else:
end_index = max(chosen_length, rejected_length)
div_index = check_divergence[0]
assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length-1])
rejected_scores.append(rejected_rewards[i, rejected_length-1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
loss = loss / batch_size
if return_outputs:
chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
return loss, [loss, chosen_scores, rejected_scores]
return loss
def save_predictions(
self,
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
writer.write("\n".join(res))

View File

@@ -0,0 +1,75 @@
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = PairwiseTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks + [SavePeftModelCallback()],
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1 @@
from llmtuner.train.sft.workflow import run_sft

View File

@@ -1,13 +1,23 @@
import numpy as np
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple, Union
from transformers.tokenization_utils import PreTrainedTokenizer
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
if is_rouge_available():
from rouge_chinese import Rouge
@dataclass
@@ -16,14 +26,14 @@ class ComputeMetrics:
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
"""
tokenizer: PreTrainedTokenizer
tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -47,6 +57,5 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}

View File

@@ -3,18 +3,20 @@ import json
import torch
import numpy as np
import torch.nn as nn
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from transformers import Seq2SeqTrainer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
class Seq2SeqPeftTrainer(PeftTrainer):
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""
@@ -31,44 +33,40 @@ class Seq2SeqPeftTrainer(PeftTrainer):
Subclass and override to inject custom behavior.
"""
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len:
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
loss, generated_tokens, labels = super().prediction_step(
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
generated_tokens = generated_tokens.contiguous()
return (loss, generated_tokens, labels)
return loss, generated_tokens, labels
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
def _pad_tensors_to_target_len(
self,
src_tensor: torch.Tensor,
tgt_tensor: torch.Tensor
) -> torch.Tensor:
r"""
Pads the tensor to the same length as the target tensor.
Should only be called when predict_with_generate=True.
"""
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor
return padded_tensor.contiguous() # in contiguous memory
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.

View File

@@ -1,79 +1,75 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = Seq2SeqPeftTrainer(
finetuning_args=finetuning_args,
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs
**split_dataset(dataset, data_args, training_args)
)
# Keyword arguments for `model.generate`
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_new_tokens": data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
# Training
if training_args.do_train:
train_result = trainer.train()
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
@@ -92,3 +88,10 @@ def run_sft(
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1,51 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.train.pt import run_pt
from llmtuner.train.sft import run_sft
from llmtuner.train.rm import run_rm
from llmtuner.train.ppo import run_ppo
from llmtuner.train.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.config.use_cache = True
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(finetuning_args.export_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
run_exp()

View File

@@ -0,0 +1,80 @@
import torch
from typing import TYPE_CHECKING, Literal, Union
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
checkpoint_dir=finetuning_args.ref_model_checkpoint,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
checkpoint_dir=finetuning_args.reward_model_checkpoint,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model

View File

@@ -1,5 +0,0 @@
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo

View File

@@ -1,2 +0,0 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer

View File

@@ -1,94 +0,0 @@
import os
import torch
from transformers.modeling_utils import PreTrainedModel
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
model: PreTrainedModel,
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
is_trainable: bool,
is_mergeable: bool
) -> PreTrainedModel:
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full":
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze":
logger.info("Fine-tuning method: Freeze")
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
param.requires_grad_(False)
else:
param.data = param.data.to(torch.float32)
if model_args.checkpoint_dir is not None:
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
latest_checkpoint = None
if model_args.checkpoint_dir is not None:
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
"Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
for checkpoint in checkpoints_to_merge:
model = PeftModel.from_pretrained(model, checkpoint)
model = model.merge_and_unload()
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if latest_checkpoint is not None: # resume lora training or quantized inference
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
if is_trainable and latest_checkpoint is None: # create new lora weights while training
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target
)
model = get_peft_model(model, lora_config)
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model

View File

@@ -1,151 +0,0 @@
import os
import torch
from typing import Literal, Optional, Tuple
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
logger = get_logger(__name__)
check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
def load_model_and_tokenizer(
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
r"""
Loads pretrained model and tokenizer.
Support both training and inference.
"""
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
"RM and PPO training can only be performed with the LoRA method."
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
}
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
padding_side=model_args.padding_side,
**config_kwargs
)
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
tokenizer.pad_token_id = 0 # set as the <unk> token
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
is_mergeable = True
# Quantization configurations (using bitsandbytes library).
if model_args.quantization_bit is not None:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if not is_trainable: # `device_map=auto` should be used for inference only
config_kwargs["device_map"] = "auto"
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
# Load and prepare pretrained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
low_cpu_mem_usage=True,
**config_kwargs
)
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
print_trainable_params(model)
return model, tokenizer

View File

@@ -1,134 +0,0 @@
import os
import sys
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
)
logger = get_logger(__name__)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
if args is not None:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
assert (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training.")
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
training_args.ddp_find_unused_parameters = False
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:
if training_args.fp16:
model_args.compute_dtype = torch.float16
elif training_args.bf16:
model_args.compute_dtype = torch.bfloat16
else:
model_args.compute_dtype = torch.float32
# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, general_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
if args is not None:
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
return model_args, data_args, finetuning_args, generating_args

View File

@@ -1,88 +0,0 @@
import os
import torch
from typing import Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self._remove_log()
def _remove_log(self):
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
Saves trainable parameters as model checkpoint.
This function will only be executed at the process zero.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
backbone_model = getattr(model, "pretrained_model")
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
else:
backbone_model = model
if self.finetuning_args.finetuning_type == "lora":
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
state_dict=get_state_dict(backbone_model),
safe_serialization=self.args.save_safetensors
)
backbone_model.config.use_cache = False
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
r"""
Loads trainable parameters from model checkpoint.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if self.finetuning_args.finetuning_type == "lora":
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
else: # freeze/full-tuning
load_trainable_params(backbone_model, self.state.best_model_checkpoint)

View File

@@ -1 +0,0 @@
from llmtuner.tuner.ppo.workflow import run_ppo

View File

@@ -1,186 +0,0 @@
import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Optional
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer
from trl.core import LengthSampler
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, get_logits_processor
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
logger = get_logger(__name__)
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: List[LogCallback],
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
self.log_callback = callbacks[0]
self.state = TrainerState()
self.control = TrainerControl()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
self._remove_log()
def ppo_train(self, max_target_length: int) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
# Keyword arguments for `model.generate`
gen_kwargs = {
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
batch = next(dataiter)
steps_trained += 1
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get responses
query_tensors = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
queries, responses = [], []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default")
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
)
print(logs)
logs["step"] = step
self.state.log_history.append(logs)
self.log_callback.on_log(self.args, self.state, self.control)
loss_meter.reset()
reward_meter.reset()
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
@torch.no_grad()
def generate(
self,
inputs: Dict[str, torch.Tensor],
length_sampler: Optional[Callable] = None,
return_prompt: Optional[bool] = True,
**generation_kwargs
) -> torch.Tensor:
r"""
Generates model's responses given queries.
Subclass and override to inject custom behavior.
"""
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
unwrapped_model = self.accelerator.unwrap_model(self.model)
response = unwrapped_model.generate(**inputs, **generation_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
Subclass and override to inject custom behavior.
"""
if self.args.should_save:
self._save(output_dir)

View File

@@ -1,37 +0,0 @@
import torch
from typing import Dict, List, Literal, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.constants import LAYERNORM_NAMES
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
valuehead_state_dict = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict({
"summary.weight": getattr(model, "{}_head_weight".format(target)),
"summary.bias": getattr(model, "{}_head_bias".format(target))
})
def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead,
layer_norm_names: List[str] = LAYERNORM_NAMES,
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
layer_norm_state_dict = {}
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
if layer_norm_params is not None:
param.data = layer_norm_params[name] # restore float32 weights
else:
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
param.data = param.data.to(torch.float16)
return model, layer_norm_state_dict

View File

@@ -1,70 +0,0 @@
# Inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
import math
from trl import PPOConfig
from torch.optim import AdamW
from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
def run_ppo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
total_train_batch_size = \
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
)
# Initialize our Trainer
ppo_trainer = PPOPeftTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=callbacks,
config=ppo_config,
model=model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -1 +0,0 @@
from llmtuner.tuner.pt.workflow import run_pt

View File

@@ -1,73 +0,0 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
import math
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
def run_pt(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = PeftTrainer(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**trainer_kwargs
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

View File

@@ -1 +0,0 @@
from llmtuner.tuner.rm.workflow import run_rm

View File

@@ -1,38 +0,0 @@
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from llmtuner.tuner.core.trainer import PeftTrainer
class PairwisePeftTrainer(PeftTrainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self,
model: PreTrainedModel,
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
We use score on the EOS token to represent reward of the whole sentence.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
Note that the first element will be removed from the output tuple.
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss

View File

@@ -1,68 +0,0 @@
# Inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
def run_rm(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer)
training_args.remove_unused_columns = False # important for pairwise dataset
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = PairwisePeftTrainer(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
**trainer_kwargs
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

View File

@@ -1 +0,0 @@
from llmtuner.tuner.sft.workflow import run_sft

Some files were not shown because too many files have changed in this diff Show More