139 Commits

Author SHA1 Message Date
hiyouga
c8fe3f544b release v0.6.3
Former-commit-id: 947572af8de201669598f54735f35b50bb719d71
2024-04-21 23:13:23 +08:00
hiyouga
0f1ad7140f fix #3366
Former-commit-id: dc20237455c36de44f8922539d7dfadd8bedb12f
2024-04-21 21:34:25 +08:00
hiyouga
233e167f68 fix optimizers
Former-commit-id: f811eee2fa12a89a55a9c5d3a05a1521b4347727
2024-04-21 20:40:54 +08:00
hiyouga
1d341dcd83 fix #3365
Former-commit-id: 415ce41e8fa887e980e5bd575c8e95bd4076b90b
2024-04-21 19:20:18 +08:00
hiyouga
d16561e7a4 fix bug in galore optimizer
Former-commit-id: c05ac23261a5a8ba893c2918a43dc7777307407b
2024-04-21 18:53:22 +08:00
hiyouga
f8e219dc81 fix mod stuff
Former-commit-id: cf3988226e6398c67bb2955578e436fc505aa5c5
2024-04-21 18:11:10 +08:00
hoshi-hiyouga
3365cc8cf0 Merge pull request #3338 from astramind-ai/main
Adding Mixture of Depth

Former-commit-id: 4da2ece53353b63e672ff529d6beba41ff710c14
2024-04-21 18:05:52 +08:00
hoshi-hiyouga
3a5e68b7d9 fix #3348
Former-commit-id: aa5e921c00f60074eceb2f9d4d8837cc713edba6
2024-04-20 10:34:09 +08:00
hiyouga
0cb596fee1 add dpo mix dataset
Former-commit-id: 6def3f8bfa51b2d9d73af112352ce07db972e4c9
2024-04-20 01:31:38 +08:00
hiyouga
b3b5b530d1 fix #3352
Former-commit-id: f315f8e8ec916b82bac94a159e55839ff155c6b5
2024-04-19 22:40:01 +08:00
hiyouga
9225c15c88 fix llama3 template
Former-commit-id: 20e95250168fbe081c779b2e1ff23f5df3ce02f7
2024-04-19 15:46:51 +08:00
Marco
abd9fed445 fix small typo
Former-commit-id: 5638a03cd0cf8119ff366b3b3e303b5a2351b065
2024-04-18 20:33:29 +02:00
Marco
44cda2eece Added Mixture of Depths
Former-commit-id: 75dd98b9abc847e22cb263c17ebcd2ca5dd98345
2024-04-18 20:31:24 +02:00
hoshi-hiyouga
8397808d1d support llama3
Former-commit-id: c1eabb751a5fd73b710714451b146732e0ed4558
2024-04-19 01:13:50 +08:00
hiyouga
9e1bd6420d fix #3324
Former-commit-id: 5e710c4ac331f3400534d33b2646c4108c898d98
2024-04-18 15:34:45 +08:00
hiyouga
619264c854 tiny fix
Former-commit-id: 86399ca8c06273c42c2b184664ae25d3405b3bf6
2024-04-18 00:22:17 +08:00
hiyouga
1ebac62e3d update readme
Former-commit-id: a49112a74339ba77bfec53f7870e821fe148db2c
2024-04-17 23:40:49 +08:00
hiyouga
ce9bdb3509 add mixtral 8x22B models
Former-commit-id: eccbeecff0909e1fa124b5439ffbbfbc5607e1d6
2024-04-17 23:35:59 +08:00
hiyouga
0c8d6369ac add CodeQwen models
Former-commit-id: 9f6094241391f8f717818c8ba94e11d1791b4a5c
2024-04-17 23:27:22 +08:00
hiyouga
bee796f6b5 fix #3316
Former-commit-id: 7395e9e90a209228ff563ab54319955608850fc3
2024-04-17 22:54:34 +08:00
hiyouga
9f6349a333 fix #3317
Former-commit-id: 7dce1763be4374cf616d96db95ae964ff510a9d6
2024-04-17 22:17:19 +08:00
hiyouga
171a029c5e lint
Former-commit-id: 917d65ce65024d17a5030bc57083a427cfae16d7
2024-04-16 18:21:09 +08:00
hoshi-hiyouga
eaefaa0fe0 Merge pull request #3291 from codemayq/main
support for previewing custom dataset in directory format

Former-commit-id: 40d89152282101a7c08f53e72c2ad7124a0595f3
2024-04-16 18:12:09 +08:00
hiyouga
d301f0a64b Update parser.py
Former-commit-id: 92c2133896c20054db86dd53508c982e39bd5ca0
2024-04-16 18:09:31 +08:00
hiyouga
0a1578e4e3 update readme and gradio version
Former-commit-id: 4029b60ddcbd15b5354503c51178f0f5e7e9aedf
2024-04-16 18:09:16 +08:00
hiyouga
a4167fd925 support badam for all stages
Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
2024-04-16 17:44:48 +08:00
hoshi-hiyouga
42084e08ae Merge pull request #3287 from Ledzy/badam
[Feature] Add BAdam algorithm

Former-commit-id: 10a5e1e65b34b03e5ca2a41bf6ded09a3fb25f0c
2024-04-16 17:32:16 +08:00
hoshi-hiyouga
9d23f5dc89 Update utils.py
Former-commit-id: 01147536b2bb507e87e033fa696e9eb39fe96bbe
2024-04-16 17:30:12 +08:00
hoshi-hiyouga
5978427ae0 Update trainer.py
Former-commit-id: c6163be1444c00dd000f288e2f834968bd932981
2024-04-16 17:29:52 +08:00
hoshi-hiyouga
c7c216069c Update utils.py
Former-commit-id: 7edf4dbed88b8034282f14fd6e0cb6f7f9e5f805
2024-04-16 17:29:30 +08:00
hoshi-hiyouga
cde9d1b917 Update patcher.py
Former-commit-id: 494e6a1e05b38f5ff61d83327303614f53c92e64
2024-04-16 17:29:19 +08:00
hoshi-hiyouga
96213f04b0 Update adapter.py
Former-commit-id: 8f7b75b26f020d8ae85baab7b082475c3bfeb512
2024-04-16 17:28:12 +08:00
hoshi-hiyouga
7ecea08b9b Update parser.py
Former-commit-id: 898239883afc79f03abd0dc276eef901662a9591
2024-04-16 17:27:25 +08:00
hoshi-hiyouga
191971865d Update parser.py
Former-commit-id: 2f3da8169d18b026760cc0ac7dd6141bdd08c932
2024-04-16 17:27:02 +08:00
hoshi-hiyouga
ff4f587dd9 Update finetuning_args.py
Former-commit-id: 3a23d900aea74078f0bc8cf73fac860a4ce3df67
2024-04-16 17:26:30 +08:00
hoshi-hiyouga
de728d0371 Update sft.sh
Former-commit-id: 2b4b1562e91bbb02e345e71b7721da9333c0791b
2024-04-16 17:25:40 +08:00
hoshi-hiyouga
d08e09642d Update requirements.txt
Former-commit-id: 1e45537ca0bb4d49b4147df01122e365b3d617e4
2024-04-16 17:10:17 +08:00
hoshi-hiyouga
351493b183 Update setup.py
Former-commit-id: 5df30ea166aff29d48ff83a22ac6ef1611ce3e35
2024-04-16 17:10:02 +08:00
Jonery
86ab47e121 remove badam from core requirements
Former-commit-id: fa5898944a3867ac5108dd0d579ca0677c87d3d6
2024-04-16 12:25:50 +08:00
Jonery
6dd6b3e396 resolve gradient checkpointing issue.
Former-commit-id: 6df9135d063bb6102f0cbcdf0d702076f5febbae
2024-04-16 12:05:27 +08:00
codingma
5f1418a68b add check
Former-commit-id: 008f6498977c243c80e87242f05c9cf9573541ac
2024-04-16 10:56:39 +08:00
codingma
7b97a79efc support for previewing custom dataset in directory format
Former-commit-id: 501cff38c819f06f15194907ce7e052d5f28025a
2024-04-16 10:43:14 +08:00
hiyouga
ce4f653121 add empty template
Former-commit-id: a325ffa8a668bec354d2636683806acef105e196
2024-04-16 03:10:02 +08:00
hiyouga
b053c6454e update readme
Former-commit-id: 8f233745c3aa7a6ef57f275bec80ee731ff76de3
2024-04-16 02:36:54 +08:00
hiyouga
ebf0f4a77c update readme
Former-commit-id: f9a246572c1ec0e4b36bff237c6523ce629b7000
2024-04-16 02:35:36 +08:00
hiyouga
efa808069a support unsloth 2024.4
Former-commit-id: 14a83f8bc4fe44783252378fce59198194a96bb8
2024-04-16 00:25:03 +08:00
hiyouga
b5c5283dd6 add codegemma
Former-commit-id: 9324176525c2eda22962b0ca1895009b6237e6e3
2024-04-16 00:11:15 +08:00
hiyouga
b638c65519 support cohere commandR #3184
Former-commit-id: e077c36872740f6b2ac255aee9da6c4c70f28977
2024-04-15 23:26:42 +08:00
Jonery
d4d471450f Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
2024-04-15 23:15:27 +08:00
hoshi-hiyouga
3144bdec2c Merge pull request #3254 from marko1616/feature/Add-support-for-CohereForAI/c4ai-command-r-plus
Add template&support for c4ai-command-r/plus (tested)

Former-commit-id: 41d39ec4889abad050820bf153133ac3a11228a3
2024-04-15 22:59:35 +08:00
hoshi-hiyouga
c6d6c4c209 Update template.py
Former-commit-id: 00b8be7dafa65e13b344724a8d3855919ee4f631
2024-04-15 22:58:01 +08:00
hoshi-hiyouga
f5f1589662 Update constants.py
Former-commit-id: 39199f712aa7b7a1c66080d9c84651fd2eb0b425
2024-04-15 22:56:55 +08:00
hiyouga
276f2cb24e update examples
Former-commit-id: 369294b31c8a03a1cafcee83eb31a817007d3c49
2024-04-15 22:14:34 +08:00
marko1616
952b785bb3 change default_system accroding to official template
Former-commit-id: 7ad9029c5e77a87a7c324b8f90b4f80a31a5c78b
2024-04-15 20:45:46 +08:00
marko1616
72dd676208 Revert "Add support for function call(Not strictly following origin)"
This reverts commit dfaa31e991 [formerly 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6].


Former-commit-id: fac9cc6e01dd8f3bc449b656804476e1871326f0
2024-04-15 20:27:09 +08:00
marko1616
dfaa31e991 Add support for function call(Not strictly following origin)
Former-commit-id: 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6
2024-04-15 20:16:52 +08:00
hoshi-hiyouga
86556b1c74 Merge pull request #3261 from khazic/main
Added specimens for single-card full parameter prediction

Former-commit-id: 60df2a9519fbd8215c3afacc831b0cc89006457a
2024-04-15 16:30:57 +08:00
hoshi-hiyouga
0c80751e87 Merge pull request #3276 from liu-zichen/fix_mixtral
fix: turn on output_router_logits of mixtral
Former-commit-id: 07bbaf5c67d00a152e5304e81b15fd9189e7bb99
2024-04-15 15:38:16 +08:00
hiyouga
9338f878a3 fix #3273
Former-commit-id: 3b20c89b342a068356ffc29c3724b645775c65db
2024-04-15 15:32:58 +08:00
liuzc
fde3d91242 fix: mixtral output_router_logits
Former-commit-id: ab3171ea97ec968b972287287ef9ee2502c6d37c
2024-04-15 12:11:49 +08:00
khazic
19adfb88a9 Upgrade README.md
Former-commit-id: 697f768d7185789ee054c94f4f161a65b8a505bc
2024-04-13 20:50:49 +08:00
khazic
daaafa900a Added specimens for single-card full parameter prediction
Former-commit-id: d8d4fb9fa4b0e1950a453682e5e186f34f085dee
2024-04-13 20:45:19 +08:00
marko1616
0dcc9e0bca Typo fix
Former-commit-id: 607625497738b2c8be736be7b0bd5c6f4cbaad5e
2024-04-13 17:30:21 +08:00
marko1616
aeec78b35c Typo fix
Former-commit-id: 51b1e49e288e66c1b0c24ac070201c988fb2a389
2024-04-13 07:52:11 +08:00
marko1616
c991654cb4 Add c4ai-command-r-plus link
Former-commit-id: acaf953ca46eca8fb378067f4ada133654e4f088
2024-04-13 07:32:40 +08:00
marko1616
f328413646 Add template&support(Not tested)
Former-commit-id: 60bb60c4dc30a9641ddb57a44ef126f0768566c4
2024-04-13 04:31:33 +08:00
hiyouga
106a0104da fix #3247
Former-commit-id: bb67c66f80627805b585d157ba807c0ce378d3f2
2024-04-12 17:41:33 +08:00
hiyouga
5486ea09e3 fix model card
Former-commit-id: 920e7149bf2b559c9829aa4b11cfb6d00bbb2f9e
2024-04-12 17:11:59 +08:00
hiyouga
31bbbb6d13 fix #3238
Former-commit-id: 4d7e81ab4722d13bec6ca1af141f94bdc74d0883
2024-04-12 14:28:11 +08:00
hiyouga
1a77de82fa set dev version
Former-commit-id: f6cc76571d2c789675883a18e0db3d0c61f33808
2024-04-11 20:27:34 +08:00
hiyouga
7468f2535c release v0.6.2
Former-commit-id: f92ad0a62d957b595f6a76a5403216b163eb3d17
2024-04-11 20:08:51 +08:00
hiyouga
38e4f22605 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Factory
Former-commit-id: 23ff02c1fd3787daf0bc6ac237c8897d02f726e4
2024-04-10 23:58:18 +08:00
hiyouga
2bc2fe7b5e fix #3225
Former-commit-id: 94110ecf27c32e263f1f2ee61842a3a301b9e089
2024-04-10 23:57:59 +08:00
hoshi-hiyouga
6d0140d8a0 Merge pull request #3201 from kno10/patch-1 and fix #3200
Pass additional_target to unsloth

Former-commit-id: 080a96c52f489fda0d315a77e26c4f6f5d69784a
2024-04-10 00:58:48 +08:00
hoshi-hiyouga
7856f98965 Update adapter.py
Former-commit-id: 720fde3683529ed7e08ac27c7c4598c6bdc30d44
2024-04-10 00:57:51 +08:00
hoshi-hiyouga
e25ddef08c Update adapter.py
Former-commit-id: a84b8d17dbf221259212e81931d80bcdd6284ad7
2024-04-10 00:57:30 +08:00
Erich Schubert
95a4589bbf Pass additional_target to unsloth
Fixes #3200

Former-commit-id: f8f87f5b0549cba6a011749c42064047f82ba577
2024-04-09 17:53:40 +02:00
hiyouga
566d71b7a9 fix quant infer and qwen2moe
Former-commit-id: b75d16767f35c36e2cf2aaab8a3844135085bccf
2024-04-09 17:12:59 +08:00
hiyouga
6030a4a720 tiny fix
Former-commit-id: d8f1ff51d4c920d4d0aeb9d53db29d1efb733c85
2024-04-08 21:28:39 +08:00
hoshi-hiyouga
5dc0cb94d4 Merge pull request #3161 from hiyouga/feature/add-mediatek-model
support Breeze-7B

Former-commit-id: af92ac8b62b919a75673011a1c56832e67882ee8
2024-04-08 20:56:51 +08:00
codingma
325dafcbb0 add empty line
Former-commit-id: 1c6c2e611d10e9fa662e3f4e1e7d23b80ae496cb
2024-04-07 18:28:08 +08:00
codingma
1a8a8b8651 rename template to breeze
Former-commit-id: 1223e6358dab52b4e1505057f1b16fd9d527c79e
2024-04-07 18:27:20 +08:00
hoshi-hiyouga
61a495cb1e Merge pull request #3160 from sliderSun/main
support Qwen1.5-32B

Former-commit-id: 1e5a5882dd494c3e9cf5eae2e0a485ce49d1863c
2024-04-07 18:00:40 +08:00
codingma
75866aa020 rename template to breeze
Former-commit-id: 1d894e7cfb73b8a29dababb554d051bd50e4f01d
2024-04-07 11:39:54 +08:00
codingma
9e4fda326d support https://github.com/hiyouga/LLaMA-Factory/issues/3152
Former-commit-id: 708f0ab4b0aa72e2c73ca36eb9ed058910e43092
2024-04-07 11:34:01 +08:00
sliderSun
1131ddfaff fix spell error
Former-commit-id: e6d36a2e593ebc1193b1735075c4ddb5d9f54990
2024-04-07 10:59:15 +08:00
sliderSun
9f437b5c43 support Qwen1.5-32B
Former-commit-id: c419adf1697b92520342f4ffa697c84bf19ca37d
2024-04-07 10:56:03 +08:00
sliderSun
0cc03d3f05 support Qwen1.5-32B
Former-commit-id: 8f2c67b95a8e177eb4096382417a70cacba38e90
2024-04-07 10:26:13 +08:00
hiyouga
04fc2f78bf update readme
Former-commit-id: 1cf15547e2420a3e5f7a969c21c10c7fbdfc71fe
2024-04-07 00:48:24 +08:00
hiyouga
3ac333fc6a update examples
Former-commit-id: de40ad62ba3d4c74c69de97b39cc79786ac28f0f
2024-04-04 14:48:21 +08:00
hiyouga
a246ac1914 tiny fix
Former-commit-id: 70aceecb27e72095c05462d01f956061669b267e
2024-04-04 02:19:03 +08:00
hiyouga
48ceac845c back to gradio 4.21 and fix chat
Former-commit-id: 695734a40a702ea059d855da54080cc8d161e41a
2024-04-04 02:07:20 +08:00
hiyouga
b1986a06b9 fix bug in latest gradio
Former-commit-id: 44a962862b4a74e50ef5786c8d5719faaa65f63f
2024-04-04 00:55:31 +08:00
hiyouga
43d134ba29 fix requires for windows
Former-commit-id: 5e25fae40b7ea9cfa72717efbe3677199ca9608f
2024-04-03 21:56:43 +08:00
hiyouga
1348f7d860 fix resize vocab at inference #3022
Former-commit-id: c243720b89eec0af2872fa3c7980a0026d893f4d
2024-04-03 18:14:24 +08:00
hiyouga
f6530222f7 fix #3116
Former-commit-id: b7256aa33d761280751518c20f29f9b8ea3fb025
2024-04-03 14:47:59 +08:00
hiyouga
a74a7585e0 update vllm example
Former-commit-id: 2df6d2eacfa27ebc69455696b93649624c1facbe
2024-04-02 22:45:20 +08:00
hiyouga
5bf0cca2b8 update readme
Former-commit-id: 7ea7333b51be6b1120fc0b13675f5a0ac3c5a12b
2024-04-02 22:17:48 +08:00
hiyouga
755b6511ff update examples
Former-commit-id: 2715cfe20f6f4532bebaa47b80ccd5df43d6a490
2024-04-02 21:09:25 +08:00
hiyouga
35621c6089 add zh readme
Former-commit-id: 389a170a4d42c56c71c0e17bbe018c4cb1983b5a
2024-04-02 20:58:45 +08:00
hiyouga
38b59664e6 update examples
Former-commit-id: c078582a759f6bce6e760cd39a05883f7eb194fe
2024-04-02 20:51:21 +08:00
hiyouga
933a084999 update examples
Former-commit-id: bf36b16e48d6438de6d0b2f2bfe33f7895699b9d
2024-04-02 20:41:49 +08:00
hiyouga
c1510d19c7 update readme
Former-commit-id: 9b8e7ccdab167f53fb897e1940562682324e8ff0
2024-04-02 20:37:37 +08:00
hiyouga
2074cf99fb update readme
Former-commit-id: 0c73d3c8a5762a8f119b27322ffd52a61de6fe38
2024-04-02 20:22:11 +08:00
hiyouga
b12176d818 simplify readme
Former-commit-id: 0da6ec2d516326fe9c7583ba71cd1778eb838178
2024-04-02 20:07:43 +08:00
hiyouga
117b67ea30 add moe aux loss control #3085
Former-commit-id: c9187ebc944e2de454ace3304b7d28eabb1b1a81
2024-04-02 14:26:31 +08:00
hiyouga
03e20bb5c6 fix #3022
Former-commit-id: dac2f617bda9470ac8d85c7e9def09cc04970506
2024-04-02 13:58:39 +08:00
hiyouga
0c4a1381a4 Update SECURITY.md
Former-commit-id: e22217c75421a89fd7e2ada62ce0e08245dd05e7
2024-04-01 23:30:03 +08:00
hiyouga
9e14501edb set dev version
Former-commit-id: 922ecae89210e5b8d62d78774f123a6d75c525ba
2024-04-01 23:24:08 +08:00
hiyouga
1dc963caa6 fix #3083
Former-commit-id: ff9a3f73961a362d0ddc22079f80a85465fffda8
2024-04-01 22:53:52 +08:00
hiyouga
85726c91ce add qwen1.5 moe
Former-commit-id: 3ea94f0d12cec25ac694a2c4ae8971c356990b61
2024-04-01 21:49:40 +08:00
hiyouga
40211db275 fix #3077
Former-commit-id: d0340391e8075cff0d84b3ef879c2101b66ca1dc
2024-04-01 21:35:18 +08:00
hiyouga
e7f13098c6 support infer 4bit model on GPUs #3023
Former-commit-id: 950a9dab9055839990656b2b40956792b253573d
2024-04-01 17:34:04 +08:00
hiyouga
61eb3a3d46 update webui
Former-commit-id: e96d260917a35ad2068f7b28b4f0b334b808ccc2
2024-04-01 16:23:28 +08:00
hiyouga
be0a807e8c fix ORPO loss
Former-commit-id: 5544ddde9087f00f9e20b78d0079f20c2f5d1604
2024-04-01 14:42:41 +08:00
hiyouga
52d402e2a9 fix IPO and ORPO loss
Former-commit-id: fc27955732aedbb12003faf19b760e2768b228f2
2024-04-01 14:37:53 +08:00
hiyouga
c5a46f9113 fix plots
Former-commit-id: 81355671296b84d438967463bb2a92934ff31aae
2024-03-31 19:43:48 +08:00
hiyouga
00e17a377c use log1p in orpo loss
https://github.com/huggingface/trl/pull/1491

Former-commit-id: 3b15d495264b00a4f8716bafea334778874963d7
2024-03-31 19:27:08 +08:00
hiyouga
9abd83adb1 update readme
Former-commit-id: 297b01f16ac78cde15a5d85a9a5b82ea20bfaf23
2024-03-31 18:46:34 +08:00
hoshi-hiyouga
f0d2afcf90 Merge pull request #3066 from hiyouga/orpo
support ORPO

Former-commit-id: fd4d3d29a9fae289f3bd0c4ce00656e4ccbec2e1
2024-03-31 18:42:48 +08:00
hiyouga
1aba442bcd support orpo in webui
Former-commit-id: dd5cc78d4fb18dd0a2e9d57f0f046cfe9f0dc2c9
2024-03-31 18:34:59 +08:00
hiyouga
d764cd8736 support ORPO
Former-commit-id: f44a4c27e2461cdaa1b16865f597a31033c0e6d9
2024-03-31 18:29:50 +08:00
hiyouga
526111a303 tiny fix
Former-commit-id: ba4a9b3c01e2f7467fbc5be268f47c0d003caa65
2024-03-31 00:10:29 +08:00
hoshi-hiyouga
b8364046df Merge pull request #3057 from marko1616/bugfix/lora-model-merge
Fix Llama model save for full param train

Former-commit-id: 18303e34d07dbf4c2dd9bac03243a3ed38582515
2024-03-31 00:07:20 +08:00
marko1616
1f617c6e08 fix blank line contains whitespace
Former-commit-id: 7bc3bcc64353d5a1d4870c6a9509b64cff710492
2024-03-30 23:46:55 +08:00
marko1616
a6858a36c0 Fix Llama model save for full param train
Former-commit-id: ca17b5db4f97c3ec9fe2004877f150e8f51ab4b5
2024-03-30 23:45:04 +08:00
hiyouga
6198121923 support save args in webui #2807 #3046
some ideas are borrowed from @marko1616


Former-commit-id: b5a062aa2d4a37670007e8b3dae5b6f5b7ffb15c
2024-03-30 23:09:12 +08:00
hiyouga
b0efebf853 upgrade gradio to 4.21.0
Former-commit-id: 63eecbeb967d849e1d03d8d03fb6421c0ee89257
2024-03-30 20:37:08 +08:00
hiyouga
fbd0584391 release v0.6.1
Former-commit-id: a59d823f554505b2e649e6e111b9dee8306d3ad8
2024-03-29 11:36:08 +08:00
hiyouga
50224b09cc update readme
Former-commit-id: 312d4f90784800dc8db4eaa7d908e6761115bc51
2024-03-28 22:02:32 +08:00
hiyouga
32dcc5a491 add project
Former-commit-id: 0418e9fecb2337b5d1b72e8358adb8aa10803c4b
2024-03-28 20:24:27 +08:00
hiyouga
9408366a36 fix #2982
Former-commit-id: e5e6a0c50c7a1c0052ed6b459450b9735ff2c9a1
2024-03-28 20:22:31 +08:00
hiyouga
f0e564beaa update readme
Former-commit-id: 6b634b5c2dbad827e8cc9850b8d7697c2056532a
2024-03-28 18:35:11 +08:00
hiyouga
14b75a0b93 fix #3010
Former-commit-id: a5e823ae75556eaa3b52ce7a887a6e7838a1eb5f
2024-03-28 18:31:17 +08:00
hiyouga
59e6ebf039 update trainers
Former-commit-id: d0dd6eefed0b86895ed00a7cafb331e5193db645
2024-03-28 18:16:27 +08:00
hoshi-hiyouga
dc540dfaa8 fix ds optimizer
Former-commit-id: 2675127070a1e7584e71039a11c1ebac54ddd1db
2024-03-26 23:39:56 +08:00
hiyouga
587e65e442 fix #2981
Former-commit-id: ede2a913856e52c0a96155705116528d3af15998
2024-03-26 17:53:04 +08:00
hiyouga
a916688723 fix bug
Former-commit-id: f513e1415cc3fe87f600318fba855d1286b6d007
2024-03-26 17:30:12 +08:00
hiyouga
3336422760 fix #2961
Former-commit-id: 616917bb3be7f71073b56ad8c7bc4e164b08b9b5
2024-03-26 17:26:14 +08:00
109 changed files with 2405 additions and 1905 deletions

.github/SECURITY.md (vendored, 2 changes)

@@ -1,6 +1,6 @@
 # Reporting Security Issues
-To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/electron/electron/security/advisories/new) tab.
+To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/hiyouga/LLaMA-Factory/security/advisories/new) tab.
 We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.

README.md (514 changes)

@@ -5,7 +5,7 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
 [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
-[![Citation](https://img.shields.io/badge/citation-26-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -44,16 +44,16 @@ Choose your path:
 ## Features
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
-- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
+- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
 ## Benchmark
-Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
+Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
 ![benchmark](assets/benchmark.svg)
@@ -62,24 +62,34 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 - **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
 - **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
 - **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
-- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
+- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
 </details>
 ## Changelog
-[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
-[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
-[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
-[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
-[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
+[24/04/19] We supported **Meta Llama 3** model series.
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
+[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
 <details><summary>Full Changelog</summary>
+[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
+[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
+[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.
+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
+[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
+[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
 [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
 [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.
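The switches named in the changelog entries above are ordinary command-line arguments to the `src/train_bash.py` entry point used in the training examples later in this diff. A minimal, hedged sketch (the flag combination is illustrative and not copied from the repository's maintained example scripts):

```bash
# Illustrative only: an SFT LoRA run with the LoRA+ ratio mentioned in the changelog.
# Paths are placeholders; see the examples/ directory for maintained scripts.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
  --stage sft \
  --do_train \
  --model_name_or_path path_to_llama_model \
  --dataset alpaca_gpt4_en \
  --template default \
  --finetuning_type lora \
  --lora_target q_proj,v_proj \
  --loraplus_lr_ratio 16.0 \
  --output_dir path_to_sft_checkpoint \
  --per_device_train_batch_size 4 \
  --learning_rate 5e-5 \
  --num_train_epochs 3.0 \
  --fp16
```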
@@ -127,21 +137,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | Model | Model size | Default module | Template |
 | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
-| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
@@ -165,9 +176,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-> [!NOTE]
-> Use `--quantization_bit 4` argument to enable QLoRA.
 ## Provided Datasets
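The `--quantization_bit 4` switch from the removed note above enables QLoRA on top of a LoRA run; a hedged sketch reusing the SFT command that appears later in this diff (paths are placeholders):

```bash
# QLoRA sketch: 4-bit quantized base model plus LoRA adapters (placeholder paths).
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
  --stage sft \
  --do_train \
  --model_name_or_path path_to_llama_model \
  --dataset alpaca_gpt4_en \
  --template default \
  --finetuning_type lora \
  --lora_target q_proj,v_proj \
  --quantization_bit 4 \
  --output_dir path_to_qlora_checkpoint \
  --per_device_train_batch_size 4 \
  --learning_rate 5e-5 \
  --num_train_epochs 3.0 \
  --fp16
```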
@@ -242,12 +251,11 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
 </details>
-Please refer to [data/README.md](data/README.md) for details.
 Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
 ```bash
@@ -261,8 +269,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python | 3.8 | 3.10 |
 | torch | 1.13.1 | 2.2.0 |
-| transformers | 4.37.2 | 4.39.1 |
-| datasets | 2.14.3 | 2.17.1 |
+| transformers | 4.37.2 | 4.39.3 |
+| datasets | 2.14.3 | 2.18.0 |
 | accelerate | 0.27.2 | 0.28.0 |
 | peft | 0.9.0 | 0.10.0 |
 | trl | 0.8.1 | 0.8.1 |
@@ -278,36 +286,39 @@ huggingface-cli login
 \* *estimated*
-| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
-| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
-| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
-| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
-| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
-| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
+| Method | Bits | 7B | 13B | 30B | 70B | 8x7B | 8x22B |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB | 1200GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB | 400GB |
+| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 120GB | 320GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB | 160GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB | 96GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB | 48GB |
 ## Getting Started
-### Data Preparation (optional)
+### Data Preparation
-Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
+Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
 > [!NOTE]
-> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
+> Please update `data/dataset_info.json` to use your custom dataset.
-### Dependence Installation (optional)
+### Dependence Installation
 ```bash
 git clone https://github.com/hiyouga/LLaMA-Factory.git
 conda create -n llama_factory python=3.10
 conda activate llama_factory
 cd LLaMA-Factory
-pip install -r requirements.txt
+pip install -e .[metrics]
 ```
+Extra dependencies available: deepspeed, metrics, unsloth, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
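Several extras from this list can be combined in a single editable install; an illustrative (unverified) combination:

```bash
# Illustrative: pick any extras from the list above and install them together.
pip install -e ".[metrics,deepspeed,bitsandbytes]"
```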
+<details><summary>For Windows users</summary>
 If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
 ```bash
@@ -316,378 +327,82 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
-### Use ModelScope Hub (optional)
-If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
+</details>
+### LLaMA Board GUI
```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```
Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path modelscope/Llama-2-7b-ms \
... # arguments (same as below)
```
LLaMA Board also supports using the models and datasets on the ModelScope Hub.
```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
### Train on a single GPU
 > [!IMPORTANT]
-> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
+> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
-#### LLaMA Board GUI
+#### Use local environment
 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_web.py
+export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
+export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
+python src/train_web.py # or python -m llmtuner.webui.interface
 ```
-#### Pre-Training
+#### Use Docker
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset wiki_demo \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### Supervised Fine-Tuning
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--top_k 0 \
--top_p 0.9 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
#### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
### Distributed Training
#### Use Huggingface Accelerate
```bash
accelerate launch --config_file config.yaml src/train_bash.py \
--ddp_timeout 180000000 \
... # arguments (same as above)
```
<details><summary>Example config.yaml for LoRA training</summary>
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
> [!TIP]
> We recommend using Accelerate for LoRA tuning.
#### Use DeepSpeed
```bash
deepspeed --num_gpus 8 src/train_bash.py \
--deepspeed ds_config.json \
--ddp_timeout 180000000 \
... # arguments (same as above)
```
<details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
}
}
```
</details>
> [!TIP]
> Refer to [examples](examples) for more training scripts.
### Merge LoRA weights and export model
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora \
--export_dir path_to_export \
--export_size 2 \
--export_legacy_format False
```
> [!WARNING]
> Merging LoRA weights into a quantized model is not supported.
> [!TIP]
> Use `--model_name_or_path path_to_export` solely to use the exported model.
>
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model with AutoGPTQ after merging the LoRA weights.
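Putting the tip above together with the export command from this section gives the following sketch (paths are placeholders):

```bash
# Hedged sketch: merge LoRA weights and quantize the merged model with AutoGPTQ.
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
  --model_name_or_path path_to_llama_model \
  --adapter_name_or_path path_to_checkpoint \
  --template default \
  --finetuning_type lora \
  --export_dir path_to_export \
  --export_quantization_bit 4 \
  --export_quantization_dataset data/c4_demo.json \
  --export_size 2 \
  --export_legacy_format False
```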
### Inference with OpenAI-style API
```bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
> [!TIP]
> Visit `http://localhost:8000/docs` for API documentation.
### Inference with command line
```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### Inference with web browser
```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### Evaluation
```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template vanilla \
--finetuning_type lora \
--task mmlu \
--split test \
--lang en \
--n_shot 5 \
--batch_size 4
```
### Predict
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 1 \
--max_samples 100 \
--predict_with_generate \
--fp16
```
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 predict.
> [!TIP]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
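Combining the prediction command above with the flags recommended in these notes gives the following sketch (paths and the 4-bit setting are placeholders):

```bash
# Hedged sketch: 4/8-bit prediction with the batch size and target length recommended above.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
  --stage sft \
  --do_predict \
  --model_name_or_path path_to_llama_model \
  --adapter_name_or_path path_to_checkpoint \
  --dataset alpaca_gpt4_en \
  --template default \
  --finetuning_type lora \
  --quantization_bit 4 \
  --output_dir path_to_predict_result \
  --per_device_eval_batch_size 1 \
  --max_target_length 128 \
  --max_samples 100 \
  --predict_with_generate \
  --fp16
```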
### Dockerize Training
#### Get ready
A working containerized environment, such as Docker or Docker Compose, is required.
#### Docker support
 ```bash
 docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all -v ./hf_cache:/root/.cache/huggingface/ -v ./data:/app/data -v ./output:/app/output -p 7860:7860 --shm-size 16G --name llama_factory -d llama-factory:latest
+docker run --gpus=all \
+    -v ./hf_cache:/root/.cache/huggingface/ \
+    -v ./data:/app/data \
+    -v ./output:/app/output \
+    -e CUDA_VISIBLE_DEVICES=0 \
+    -p 7860:7860 \
+    --shm-size 16G \
+    --name llama_factory \
+    -d llama-factory:latest
 ```
-#### Docker Compose support
+#### Use Docker Compose
 ```bash
 docker compose -f ./docker-compose.yml up -d
 ```
-> [!TIP]
-> Details about volume:
-> * hf_cache: Utilize Huggingface cache on the host machine. Reassignable if a cache already exists in a different directory.
-> * data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
-> * output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
+<details><summary>Details about volume</summary>
+- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
+- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
+- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
+</details>
### Command Line Interface
See [examples/README.md](examples/README.md) for usage.
Use `python src/train_bash.py -h` to display arguments description.
### Deploy with OpenAI-style API and vLLM
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--template mistral \
--infer_backend vllm \
--vllm_enforce_eager
```
### Use ModelScope Hub
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `modelscope/Llama-2-7b-ms`.
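A sketch of such a run, combining the ModelScope model ID from the removed example with the SFT command shown earlier in this diff (all other arguments are placeholders):

```bash
# Hedged sketch: fine-tune a model pulled from the ModelScope Hub instead of Hugging Face.
export USE_MODELSCOPE_HUB=1
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
  --stage sft \
  --do_train \
  --model_name_or_path modelscope/Llama-2-7b-ms \
  --dataset alpaca_gpt4_en \
  --template default \
  --finetuning_type lora \
  --lora_target q_proj,v_proj \
  --output_dir path_to_sft_checkpoint \
  --fp16
```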
 ## Projects using LLaMA Factory
+If you have a project that should be incorporated, please contact via email or create a pull request.
+<details><summary>Click to show</summary>
 1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
 1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
@@ -709,20 +424,27 @@ docker compose -f ./docker-compose.yml up -d
 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
+1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
-> [!TIP]
-> If you have a project that should be incorporated, please contact via email or create a pull request.
+</details>
 ## License
 This repository is licensed under the [Apache-2.0 License](LICENSE).
-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 ## Citation
@@ -730,7 +452,7 @@ If this work is helpful, please kindly cite as:
 ```bibtex
 @article{zheng2024llamafactory,
   title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
   author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
   journal={arXiv preprint arXiv:2403.13372},
   year={2024},
@@ -740,7 +462,7 @@ If this work is helpful, please kindly cite as:
## Acknowledgement ## Acknowledgement
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works. This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
## Star History ## Star History

[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
## Features

- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO training, DPO training and ORPO training.
- **Scalable precisions**: 32-bit full-parameter tuning, 16-bit freeze-tuning, 16-bit LoRA tuning and 2/4/8-bit QLoRA tuning via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, browser UI and CLI powered by vLLM.

## Benchmark

Compared to ChatGLM's official [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) fine-tuning, LLaMA Factory's LoRA tuning offers a **3.7 times faster** training speed with a higher Rouge score on the advertising-text generation task. Combined with 4-bit quantization, LLaMA Factory's QLoRA tuning further reduces GPU memory usage.

![benchmark](assets/benchmark.svg)

- **Training Speed**: the number of training samples processed per second. (batch size=4, cutoff length=1024)
- **Rouge Score**: Rouge-2 score on the validation set of the [advertising-text generation](https://aclanthology.org/D19-1321.pdf) task. (batch size=4, cutoff length=1024)
- **GPU Memory**: peak GPU memory usage in 4-bit quantized training. (batch size=1, cutoff length=1024)
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.

</details>
## Changelog

[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** training based on the [AstraMindAI implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.

[24/04/19] We supported the **Meta Llama 3** model series.

[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory savings compared with FlashAttention-2. More benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).

<details><summary>Full Changelog</summary>

[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.

[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!

[24/03/20] We supported **FSDP+QLoRA**, which fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.

[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.

[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**). See `examples/extras/galore` for usage.

[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported; merge the weights first.)

[24/02/28] We supported **[DoRA](https://arxiv.org/abs/2402.09353)** fine-tuning. Try `--use_dora` to activate DoRA training.

[24/02/15] We supported the **block expansion** method proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.

</details>
| Model                                                     | Model size                  | Default module  | Template  |
| --------------------------------------------------------- | --------------------------- | --------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc)           | 7B/13B                      | W_pack          | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience)                 | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [BLOOMZ](https://huggingface.co/bigscience)                | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [ChatGLM3](https://huggingface.co/THUDM)                   | 6B                          | query_key_value | chatglm3  |
| [Command-R](https://huggingface.co/CohereForAI)            | 35B/104B                    | q_proj,v_proj   | cohere    |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai)       | 7B/16B/67B                  | q_proj,v_proj   | deepseek  |
| [Falcon](https://huggingface.co/tiiuae)                    | 7B/40B/180B                 | query_key_value | falcon    |
| [Gemma/CodeGemma](https://huggingface.co/google)           | 2B/7B                       | q_proj,v_proj   | gemma     |
| [InternLM2](https://huggingface.co/internlm)               | 7B/20B                      | wqkv            | intern2   |
| [LLaMA](https://github.com/facebookresearch/llama)         | 7B/13B/33B/65B              | q_proj,v_proj   | -         |
| [LLaMA-2](https://huggingface.co/meta-llama)               | 7B/13B/70B                  | q_proj,v_proj   | llama2    |
| [LLaMA-3](https://huggingface.co/meta-llama)               | 8B/70B                      | q_proj,v_proj   | llama3    |
| [Mistral/Mixtral](https://huggingface.co/mistralai)        | 7B/8x7B/8x22B               | q_proj,v_proj   | mistral   |
| [OLMo](https://huggingface.co/allenai)                     | 1B/7B                       | att_proj        | olmo      |
| [Phi-1.5/2](https://huggingface.co/microsoft)              | 1.3B/2.7B                   | q_proj,v_proj   | -         |
| [Qwen](https://huggingface.co/Qwen)                        | 1.8B/7B/14B/72B             | c_attn          | qwen      |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen)          | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj   | qwen      |
| [StarCoder2](https://huggingface.co/bigcode)               | 3B/7B/15B                   | q_proj,v_proj   | -         |
| [XVERSE](https://huggingface.co/xverse)                    | 7B/13B/65B                  | q_proj,v_proj   | xverse    |
| [Yi](https://huggingface.co/01-ai)                         | 6B/9B/34B                   | q_proj,v_proj   | yi        |
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training    | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training    | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training   | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

> [!NOTE]
> Use `--quantization_bit 4` to enable QLoRA training, as sketched below.
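A minimal sketch of what that looks like in practice, assuming a LoRA SFT run with placeholder paths (the other flags are the same ones used in the training commands later in this README):

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --quantization_bit 4 \
    --output_dir path_to_qlora_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16
```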
## Datasets

- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

</details>

Please refer to [data/README_zh.md](data/README_zh.md) for details on how to use them.

Some datasets require confirmation before use, so we recommend logging in to your Hugging Face account with the following command.

```bash
huggingface-cli login
```
| Mandatory    | Minimum | Recommend |
| ------------ | ------- | --------- |
| python       | 3.8     | 3.10      |
| torch        | 1.13.1  | 2.2.0     |
| transformers | 4.37.2  | 4.39.3    |
| datasets     | 2.14.3  | 2.18.0    |
| accelerate   | 0.27.2  | 0.28.0    |
| peft         | 0.9.0   | 0.10.0    |
| trl          | 0.8.1   | 0.8.1     |

\* *estimated*

| Method            | Bits |   7B  |  13B  |  30B  |   70B  |  8x7B |  8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
| Full              | AMP  | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
| Full              |  16  |  60GB | 120GB | 300GB |  600GB | 400GB | 1200GB |
| Freeze            |  16  |  20GB |  40GB |  80GB |  200GB | 160GB |  400GB |
| LoRA/GaLore/BAdam |  16  |  16GB |  32GB |  64GB |  160GB | 120GB |  320GB |
| QLoRA             |   8  |  10GB |  20GB |  40GB |   80GB |  60GB |  160GB |
| QLoRA             |   4  |   6GB |  12GB |  24GB |   48GB |  30GB |   96GB |
| QLoRA             |   2  |   4GB |   8GB |  16GB |   24GB |  18GB |   48GB |
## Getting Started

### Data Preparation

Please refer to [data/README_zh.md](data/README_zh.md) for the format of dataset files. You can use datasets on the HuggingFace / ModelScope hub or load datasets from local disk.

> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset.

### Install Dependencies

```bash
git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
pip install -e .[metrics]
```

Extra dependencies available: deepspeed, metrics, unsloth, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality. Install the ones you need as shown below.
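For example, a sketch of pulling in several extras at once (pick only the ones you actually need; quote the argument if your shell expands brackets):

```bash
pip install -e ".[metrics,deepspeed,vllm,bitsandbytes]"
```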
<details><summary>For Windows users</summary>

If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the release that matches your CUDA version from the [pre-built wheels](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels).

If you want to enable FlashAttention-2 on the Windows platform, you need to install the pre-built `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.

</details>
### LLaMA Board GUI

> [!IMPORTANT]
> LLaMA Board GUI currently only supports training on a single GPU. Please use the [command line interface](#command-line-interface) for distributed training.

#### Use a local environment

```bash
export CUDA_VISIBLE_DEVICES=0 # use `set CUDA_VISIBLE_DEVICES=0` on Windows
export GRADIO_SERVER_PORT=7860 # use `set GRADIO_SERVER_PORT=7860` on Windows
python src/train_web.py # or python -m llmtuner.webui.interface
```
#### Use Docker
```bash
docker build -f ./Dockerfile -t llama-factory:latest .
docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \
-v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \
--shm-size 16G \
--name llama_factory \
-d llama-factory:latest
```
#### Use Docker Compose
```bash
docker compose -f ./docker-compose.yml up -d
```
<details><summary>Details about volumes</summary>

- hf_cache: Mounts the Hugging Face cache directory of the host machine; it can be changed to a different directory.
- data: The folder on the host machine where your datasets are stored.
- output: Set the export directory to this path so that the exported model can be accessed on the host machine.
</details>
### Command Line Interface

See [examples/README_zh.md](examples/README_zh.md) for usage.

Use `python src/train_bash.py -h` to display the argument documentation.

### Deploy with OpenAI-style API and vLLM
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--template mistral \
--infer_backend vllm \
--vllm_enforce_eager
```
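Once the server is running, it can be queried like any OpenAI-compatible endpoint. The following curl call is a sketch: the `/v1/chat/completions` path and the request fields are assumptions based on the OpenAI chat API convention rather than flags documented here.

```bash
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
          "model": "mistralai/Mistral-7B-Instruct-v0.2",
          "messages": [{"role": "user", "content": "Hello!"}]
        }'
```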
### Use ModelScope Hub

If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope as follows.

```bash
export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
```

Set `--model_name_or_path` to a model ID to load the corresponding model. You can find all available models on the [ModelScope Hub](https://modelscope.cn/models), for example, `modelscope/Llama-2-7b-ms`:

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path modelscope/Llama-2-7b-ms \
    ... # arguments (same as below)
```

LLaMA Board also supports downloading models and datasets from ModelScope Hub.
```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
### Single-GPU Training

> [!IMPORTANT]
> If you want to train models on multiple GPUs, please refer to [Multi-GPU Distributed Training](#multi-gpu-distributed-training).
#### LLaMA Board GUI
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
#### Pre-Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset wiki_demo \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### Supervised Fine-Tuning
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--top_k 0 \
--top_p 0.9 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer with the fine-tuned model, as sketched below.
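For instance, a sketch of chat-style inference that stacks the SFT and PPO adapters, using the `cli_demo.py` entry point shown later in this README (paths are placeholders):

```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint \
    --template default \
    --finetuning_type lora
```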
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
#### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer with the fine-tuned model.

### Multi-GPU Distributed Training

#### Use Huggingface Accelerate
```bash
accelerate launch --config_file config.yaml src/train_bash.py \
--ddp_timeout 180000000 \
    ... # arguments (same as above)
```
<details><summary>Example config.yaml for LoRA training with Accelerate</summary>
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
> [!TIP]
> We recommend using Accelerate for LoRA training.

#### Use DeepSpeed
```bash
deepspeed --num_gpus 8 src/train_bash.py \
--deepspeed ds_config.json \
--ddp_timeout 180000000 \
    ... # arguments (same as above)
```
<details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
}
}
```
</details>
> [!TIP]
> See [examples](examples) for more training scripts.

### Merge LoRA Weights and Export Model
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora \
--export_dir path_to_export \
--export_size 2 \
--export_legacy_format False
```
> [!WARNING]
> Merging LoRA weights into a quantized model is not supported.

> [!TIP]
> Use `--model_name_or_path path_to_export` solely to load the exported model.
>
> After merging the LoRA weights, you can quantize the model with AutoGPTQ by additionally specifying `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json`, as sketched below.
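Putting that tip together with the export command above, a sketch of quantizing the merged model with AutoGPTQ (the output directory name is illustrative):

```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
    --model_name_or_path path_to_export \
    --template default \
    --export_dir path_to_quantized_export \
    --export_quantization_bit 4 \
    --export_quantization_dataset data/c4_demo.json \
    --export_size 2 \
    --export_legacy_format False
```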
### Inference with OpenAI-style API
```bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
> [!TIP]
> Visit `http://localhost:8000/docs` for the API documentation.

### Inference with Command Line
```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### Inference with Web Browser
```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### Model Evaluation
```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template vanilla \
--finetuning_type lora \
--task ceval \
--split validation \
--lang zh \
--n_shot 5 \
--batch_size 4
```
### Model Prediction
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 1 \
--max_samples 100 \
--predict_with_generate \
--fp16
```
> [!WARNING]
> Use `--per_device_eval_batch_size=1` for LLaMA-2 models in fp16 prediction.

> [!TIP]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when predicting with quantized models.
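Combining that advice with the prediction command above, a sketch of batch prediction on a 4-bit quantized model (paths are placeholders):

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --quantization_bit 4 \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 1 \
    --max_samples 100 \
    --max_target_length 128 \
    --predict_with_generate
```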
## Projects Using LLaMA Factory

If you have a project that should be incorporated into the list below, please contact us via email or create a pull request.

<details><summary>Click to show</summary>
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model for the Chinese legal domain, fine-tuned from Baichuan-13B, capable of legal reasoning and knowledge retrieval.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A Chinese medical large language model, fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, fine-tuned from LLaMA2-7B and Baichuan-13B on Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.

</details>
## License

The code in this repository is open-sourced under the [Apache-2.0 License](LICENSE).

Please follow the corresponding model licenses when using the model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

If this work is helpful, please kindly cite as shown in the Citation section above.

## Acknowledgement

This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.

## Star History

Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
----
Currently we support datasets in **alpaca** or **sharegpt** format. A dataset in alpaca format should follow the below format (only the core fields implied by the loader scripts in this document are shown):

```json
[
  {
    "instruction": "user instruction (required)",
    "input": "user input (optional)",
    "output": "model response (required)"
  }
]
```

For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answer appearing first.

Remember to set `"ranking": true` for the preference datasets.
----
The dataset in **sharegpt** format is also supported; it stores each sample as a `conversations` list of `{"from": ..., "value": ...}` turns, as in the loader scripts below.

After adding it, you can use the custom dataset by specifying `--dataset dataset_name`.
----
The project currently supports two dataset formats: **alpaca** and **sharegpt**. A dataset in alpaca format is organized as follows (only the core fields implied by the loader scripts in this document are shown):

```json
[
  {
    "instruction": "user instruction (required)",
    "input": "user input (optional)",
    "output": "model response (required)"
  }
]
```

Preference datasets additionally require `"ranking": true`.
----
A dataset in **sharegpt** format is also supported; it stores each sample as a `conversations` list of `{"from": ..., "value": ...}` turns.

import json
import os

import datasets
class BelleMultiturn(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self):
        features = datasets.Features(
            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_path = dl_manager.download(_URL)
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str):
        with open(filepath, "r", encoding="utf-8") as f:
                assist_idx = prompt.rfind("Assistant:")
                human_idx = prompt.rfind("Human:")
                query = prompt[human_idx + 6 : assist_idx].strip()
                prompt = prompt[:human_idx].strip()
                conversations.insert(0, {"from": "gpt", "value": response})
                conversations.insert(0, {"from": "human", "value": query})
                    assist_idx = prompt.rfind("Assistant:")
                    human_idx = prompt.rfind("Human:")
                    if human_idx != -1:
                        old_query = prompt[human_idx + 6 : assist_idx].strip()
                        old_resp = prompt[assist_idx + 10 :].strip()
                        conversations.insert(0, {"from": "gpt", "value": old_resp})
                        conversations.insert(0, {"from": "human", "value": old_query})
                    else:

import json
from typing import Any, Dict, Generator, List, Tuple

import datasets


_DESCRIPTION = "An example of dataset."
_CITATION = ""
class ExampleDataset(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features(
            {
                "instruction": datasets.Value("string"),
                "input": datasets.Value("string"),
                "output": datasets.Value("string"),
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download(_URL)
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
        example_dataset = json.load(open(filepath, "r", encoding="utf-8"))

import json
import os
from typing import List

import datasets


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
        _URL + "harmless-base/train.jsonl.gz",
        _URL + "helpful-base/train.jsonl.gz",
        _URL + "helpful-online/train.jsonl.gz",
        _URL + "helpful-rejection-sampled/train.jsonl.gz",
    ],
    "test": [
        _URL + "harmless-base/test.jsonl.gz",
        _URL + "helpful-base/test.jsonl.gz",
        _URL + "helpful-online/test.jsonl.gz",
        _URL + "helpful-rejection-sampled/test.jsonl.gz",
    ],
}


class HhRlhfEn(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features(
            {
                "instruction": datasets.Value("string"),
                "output": datasets.Sequence(datasets.Value("string")),
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_path = dl_manager.download_and_extract(_URLS)
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
        ]

    def _generate_examples(self, filepaths: List[str]):
            rejected = data["rejected"]
            assist_idx = rejected.rfind("\n\nAssistant: ")
            r_reject = rejected[assist_idx + 13 :].strip()
            assist_idx = chosen.rfind("\n\nAssistant: ")
            r_accept = chosen[assist_idx + 13 :].strip()
            human_idx = chosen.rfind("\n\nHuman: ")
            query = chosen[human_idx + 9 : assist_idx].strip()
            prompt = chosen[:human_idx]
            history = []
                assist_idx = prompt.rfind("\n\nAssistant: ")
                human_idx = prompt.rfind("\n\nHuman: ")
                if human_idx != -1:
                    old_query = prompt[human_idx + 9 : assist_idx].strip()
                    old_resp = prompt[assist_idx + 13 :].strip()
                    history.insert(0, (old_query, old_resp))
                else:
                    break
                prompt = prompt[:human_idx]

            yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
            key += 1

import json
import os
from typing import List

import datasets


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
class UltraChat(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self):
        features = datasets.Features(
            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)]  # multiple shards
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]

    def _generate_examples(self, filepaths: List[str]):
        for filepath in filepaths:
                for row in f:
                    try:
                        data = json.loads(row)
                    except Exception:
                        continue
                    key: int = data["id"]
                    content: List[str] = data["data"]
                        content.pop(-1)
                    if len(content) < 2:
                        continue
                    conversations = [
                        {"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
                    ]
                    yield key, {"conversations": conversations}

      - ./hf_cache:/root/.cache/huggingface/
      - ./data:/app/data
      - ./output:/app/output
    environment:
      - CUDA_VISIBLE_DEVICES=0
    ports:
      - "7860:7860"
    ipc: host

We provide diverse examples about fine-tuning LLMs.
```
examples/
├── lora_single_gpu/
│ ├── pretrain.sh: Do continuous pre-training using LoRA
│ ├── sft.sh: Do supervised fine-tuning using LoRA
│ ├── reward.sh: Do reward modeling using LoRA
│ ├── ppo.sh: Do PPO training using LoRA
│ ├── dpo.sh: Do DPO training using LoRA
│ ├── orpo.sh: Do ORPO training using LoRA
│ ├── prepare.sh: Save tokenized dataset
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
├── qlora_single_gpu/
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
│ ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
│ └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
├── lora_multi_gpu/
│ ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
│ └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
├── full_multi_gpu/
│ ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
│ ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after full tuning
├── merge_lora/
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
│ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
├── inference/
│ ├── cli_demo.sh: Launch a command line interface with LoRA adapters
│ ├── api_demo.sh: Launch an OpenAI-style API with LoRA adapters
│ ├── web_demo.sh: Launch a web interface with LoRA adapters
│ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
└── extras/
├── galore/
│ └── sft.sh: Fine-tune model with GaLore
├── badam/
│ └── sft.sh: Fine-tune model with BAdam
├── loraplus/
│ └── sft.sh: Fine-tune model using LoRA+
├── mod/
│ └── sft.sh: Fine-tune model using Mixture-of-Depths
├── llama_pro/
│ ├── expand.sh: Expand layers in the model
│ └── sft.sh: Fine-tune the expanded model
└── fsdp_qlora/
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
```
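These are plain shell scripts; a typical invocation (a sketch, assuming you run each script from its own directory, since the scripts reference the repository via relative paths, and edit the arguments inside first) is simply:

```bash
cd examples/lora_single_gpu
bash sft.sh
```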

我们提供了多样化的大模型微调示例脚本。
```
examples/
├── lora_single_gpu/
│ ├── pretrain.sh: 基于 LoRA 进行增量预训练
│ ├── sft.sh: 基于 LoRA 进行指令监督微调
│ ├── reward.sh: 基于 LoRA 进行奖励模型训练
│ ├── ppo.sh: 基于 LoRA 进行 PPO 训练
│ ├── dpo.sh: 基于 LoRA 进行 DPO 训练
│ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
│ ├── prepare.sh: 保存预处理后的数据集
│ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
├── qlora_single_gpu/
│ ├── bitsandbytes.sh: 基于 QLoRA 微调 4/8 比特 BNB 模型
│ ├── gptq.sh: 基于 QLoRA 微调 4/8 比特 GPTQ 模型
│ ├── awq.sh: 基于 QLoRA 微调 4 比特 AWQ 模型
│ └── aqlm.sh: 基于 QLoRA 微调 2 比特 AQLM 模型
├── lora_multi_gpu/
│ ├── single_node.sh: 使用 Accelerate 进行单节点 LoRA 训练
│ └── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
├── full_multi_gpu/
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点全量训练
│ ├── multi_node.sh: 使用 DeepSpeed 进行多节点全量训练
│ └── predict.sh: 基于全量训练进行批量预测并计算 BLEU 和 ROUGE 分数
├── merge_lora/
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
│ └── quantize.sh: 使用 AutoGPTQ 量化微调后的模型
├── inference/
│ ├── cli_demo.sh: 启动 LoRA 模型的命令行推理接口
│ ├── api_demo.sh: 启动 LoRA 模型的 OpenAI 风格 API
│ ├── web_demo.sh: 启动 LoRA 模型的浏览器推理接口
│ └── evaluate.sh: 在 MMLU/CMMLU/C-Eval 数据集上评测 LoRA 模型
└── extras/
├── galore/
│ └── sft.sh: 使用 GaLore 训练模型
├── badam/
│ └── sft.sh: 使用 BAdam 训练模型
├── loraplus/
│ └── sft.sh: 使用 LoRA+ 训练模型
├── mod/
│ └── sft.sh: 使用深度混合训练模型
├── llama_pro/
│ ├── expand.sh: 扩展模型中的层
│ └── sft.sh: 训练扩展后的模型
└── fsdp_qlora/
└── sft.sh: 使用 FSDP+QLoRA 微调量化模型
```

  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 16 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1 # the number of nodes
num_processes: 4 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 16 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --mixture_of_depths convert \
    --output_dir ../../../saves/LLaMA2-7B/mod/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --optim paged_adamw_8bit \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \

    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --use_badam \
    --badam_switch_mode descending \
    --badam_switch_block_every 50 \
    --badam_verbose 2 \
    --output_dir ../../../saves/LLaMA2-7B/badam/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --pure_bf16

#!/bin/bash

pip install "transformers>=4.39.1"
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"

CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file ../../accelerate/fsdp_config.yaml \
    ../../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-70b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../../saves/LLaMA2-70B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --ddp_timeout 180000000 \
    --quantization_bit 4 \
    --plot_loss \
    --fp16

    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --use_galore \
    --galore_layerwise \
    --galore_target mlp,self_attn \
    --galore_rank 128 \
    --galore_scale 2.0 \
    --output_dir ../../../saves/LLaMA2-7B/galore/sft \
    --overwrite_cache \
    --overwrite_output_dir \

    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --loraplus_lr_ratio 16.0 \
    --output_dir ../../saves/LLaMA2-7B/loraplus/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16

View File

@@ -1,5 +0,0 @@
```bash
pip install git+https://github.com/huggingface/transformers.git
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
```

View File

@@ -33,6 +33,6 @@ python -m torch.distributed.run \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
- --ddp_timeout 1800000 \
+ --ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,18 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type full \
--output_dir ../../saves/LLaMA2-7B/full/predict \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_eval_batch_size 1 \
--max_samples 20 \
--predict_with_generate

View File

@@ -27,6 +27,6 @@ deepspeed --num_gpus 4 ../../src/train_bash.py \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
- --ddp_timeout 1800000 \
+ --ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora

View File

@@ -0,0 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora

View File

@@ -0,0 +1,12 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template fewshot \
--finetuning_type lora \
--task mmlu \
--split test \
--lang en \
--n_shot 5 \
--batch_size 4

View File

@@ -0,0 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora

View File

@@ -30,6 +30,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
- --ddp_timeout 1800000 \
+ --ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -30,6 +30,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
- --ddp_timeout 1800000 \
+ --ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -1,8 +0,0 @@
Usage:
- `pretrain.sh`: do pre-train (optional)
- `sft.sh`: do supervised fine-tune
- `reward.sh`: do reward modeling (must after sft.sh)
- `ppo.sh`: do PPO training (must after sft.sh and reward.sh)
- `dpo.sh`: do DPO training (must after sft.sh)
- `predict.sh`: do predict (must after sft.sh and dpo.sh)

View File

@@ -6,7 +6,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
- --dataset comparison_gpt4_en \
+ --dataset orca_rlhf \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \

View File

@@ -1,25 +1,22 @@
#!/bin/bash
- CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
+ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
- --stage sft \
+ --stage orpo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
- --dataset alpaca_gpt4_en,glaive_toolcall \
+ --dataset orca_rlhf \
- --dataset_dir ../../../data \
+ --dataset_dir ../../data \
--template default \
- --finetuning_type full \
- --use_galore \
- --galore_layerwise \
- --galore_target mlp,self_attn \
- --galore_rank 128 \
- --output_dir ../../../saves/LLaMA2-7B/galore/sft \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --output_dir ../../saves/LLaMA2-7B/lora/orpo \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
- --gradient_accumulation_steps 1 \
+ --gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
@@ -27,9 +24,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
- --learning_rate 5e-5 \
+ --learning_rate 1e-5 \
- --num_train_epochs 3.0 \
+ --num_train_epochs 1.0 \
- --max_samples 3000 \
+ --max_samples 1000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,18 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--max_samples 3000 \
--tokenized_path ../../saves/datasets/sft
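
The new script above only tokenizes the corpus (note the empty `CUDA_VISIBLE_DEVICES=`, so it can run on CPU) and writes the result to `--tokenized_path`; per the loader change later in this diff, such a run saves the cache and exits, and a subsequent run pointed at the same `--tokenized_path` trains directly from it. A minimal Python sketch of the same two-step flow, assuming `run_exp` accepts the keyword dictionary that `get_train_args` parses (all paths are illustrative):

```python
# Run this script twice: the first run tokenizes the dataset, saves it to
# tokenized_path and exits; the second run finds the cache and trains from it.
from llmtuner import run_exp

run_exp(
    dict(
        stage="sft",
        do_train=True,
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        dataset="alpaca_gpt4_en,glaive_toolcall",
        template="default",
        finetuning_type="lora",
        lora_target="q_proj,v_proj",
        output_dir="../../saves/LLaMA2-7B/lora/sft",
        cutoff_len=1024,
        tokenized_path="../../saves/datasets/sft",  # illustrative cache location
    )
)
```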

View File

@@ -6,7 +6,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
- --dataset comparison_gpt4_en \
+ --dataset orca_rlhf \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \

View File

@@ -1,4 +0,0 @@
Usage:
- `merge.sh`: merge the lora weights
- `quantize.sh`: quantize the model with AutoGPTQ (must after merge.sh, optional)

View File

@@ -1,4 +1,5 @@
#!/bin/bash
+ # DO NOT use quantized model or quantization_bit when merging lora weights
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@@ -2,9 +2,9 @@ torch>=1.13.1
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.27.2
- peft>=0.9.0
+ peft>=0.10.0
trl>=0.8.1
- gradio>=3.38.0,<4.0.0
+ gradio>=4.0.0
scipy
einops
sentencepiece
@@ -15,4 +15,3 @@ fastapi
sse-starlette
matplotlib
fire
- galore-torch

View File

@@ -15,7 +15,7 @@ from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer from llmtuner.model import load_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
@@ -32,7 +32,7 @@ def calculate_lr(
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate, is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
): ):
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
stage=stage, stage=stage,
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
@@ -44,8 +44,8 @@ def calculate_lr(
overwrite_cache=True, overwrite_cache=True,
) )
) )
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False) tokenizer = load_tokenizer(model_args)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage) trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
if stage == "pt": if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft": elif stage == "sft":

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
from llmtuner.data import get_dataset from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer from llmtuner.model import load_tokenizer
def length_cdf( def length_cdf(
@@ -20,7 +20,7 @@ def length_cdf(
template: Optional[str] = "default", template: Optional[str] = "default",
interval: Optional[int] = 1000, interval: Optional[int] = 1000,
): ):
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
stage="sft", stage="sft",
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
@@ -32,7 +32,7 @@ def length_cdf(
overwrite_cache=True, overwrite_cache=True,
) )
) )
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False) tokenizer = load_tokenizer(model_args)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft") trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
total_num = len(trainset) total_num = len(trainset)
length_dict = defaultdict(int) length_dict = defaultdict(int)

View File

@@ -1,114 +0,0 @@
# coding=utf-8
# Converts the InternLM2 model in the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional
import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
internlm2_config_dict: Dict[str, Any] = json.load(f)
internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
internlm2_state_dict.update(shard_weight)
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
if "output" in key:
llama2_state_dict[key.replace("output", "lm_head")] = value
elif "tok_embeddings" in key:
llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
elif "wqkv" in key:
num_q_heads = internlm2_config_dict["num_attention_heads"]
num_kv_heads = internlm2_config_dict["num_key_value_heads"]
q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
q_size : q_size + kv_size, ...
]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
elif "wo" in key:
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
elif "attention_norm" in key:
llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "ffn_norm" in key:
llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
elif "w1" in key:
llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
elif "w2" in key:
llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
elif "w3" in key:
llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
else:
llama2_state_dict[key] = value
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict.pop("auto_map", None)
llama2_config_dict.pop("bias", None)
llama2_config_dict.pop("rope_scaling", None)
llama2_config_dict["model_type"] = "llama"
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_internlm2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
raise print("Output dir already exists", e)
save_weight(input_dir, output_dir, shard_size, save_safetensors)
save_config(input_dir, output_dir)
if __name__ == "__main__":
fire.Fire(llamafy_internlm2)

View File

@@ -20,15 +20,18 @@ def get_requires():
extra_require = {
-     "deepspeed": ["deepspeed"],
+     "deepspeed": ["deepspeed>=0.10.0"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
    "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
+     "galore": ["galore-torch"],
+     "badam": ["badam"],
    "vllm": ["vllm>=0.3.3"],
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
    "awq": ["autoawq"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "qwen": ["tiktoken", "transformers_stream_generator"],
+     "modelscope": ["modelscope"],
    "quality": ["ruff"],
}

View File

@@ -2,8 +2,7 @@ from llmtuner import Evaluator
def main():
-     evaluator = Evaluator()
-     evaluator.eval()
+     Evaluator().eval()
if __name__ == "__main__":

View File

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
- __version__ = "0.6.0"
+ __version__ = "0.6.3"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]: elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
input_messages.append({"role": role_mapping[message.role], "content": message.content}) if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list): if isinstance(tool_list, list) and len(tool_list):
try: try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False) tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:

View File

@@ -1,6 +1,6 @@
import time import time
from enum import Enum, unique from enum import Enum, unique
from typing import List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Literal from typing_extensions import Literal
@@ -39,6 +39,17 @@ class Function(BaseModel):
arguments: str arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default" id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function" type: Literal["function"] = "function"
@@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Role role: Role
content: str content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel): class ChatCompletionMessage(BaseModel):
@@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: List[ChatMessage]
tools: list = [] tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True do_sample: bool = True
temperature: Optional[float] = None temperature: Optional[float] = None
top_p: Optional[float] = None top_p: Optional[float] = None
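
With these schema changes, `tools` becomes a typed list of `FunctionAvailable` objects and an assistant message may carry `tool_calls` with an optional `content`. A hedged sketch of what a request body could look like under the new models; the weather tool, its parameters, and the lowercase role strings are illustrative assumptions, not values taken from the repository:

```python
# Illustrative request payload matching the pydantic models above.
request_payload = {
    "model": "llama3-8b-chat",  # made-up model name
    "messages": [
        {"role": "user", "content": "What is the weather in Berlin?"},
        {
            "role": "assistant",
            "content": None,  # content may now be omitted when tool_calls are present
            "tool_calls": [
                {
                    "id": "call_default",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": "{\"city\": \"Berlin\"}"},
                }
            ],
        },
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
    ],
}
```

Per the `create_app` change earlier in this diff, such an assistant `tool_calls` entry is folded into a function-role message before it reaches the chat template.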

View File

@@ -5,14 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass @dataclass
class Response: class Response:

View File

@@ -9,7 +9,7 @@ from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@@ -30,11 +30,12 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ) -> None:
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer( self.tokenizer = load_tokenizer(model_args)
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.tokenizer.padding_side = "left" if self.can_generate else "right" self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict() self.generating_args = generating_args.to_dict()
@staticmethod @staticmethod

View File

@@ -1,8 +1,6 @@
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available from ..extras.packages import is_vllm_available
@@ -25,7 +23,6 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path, model=model_args.model_name_or_path,

View File

@@ -1,6 +1,15 @@
+ from .collator import PairwiseDataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
- __all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
+ __all__ = [
+     "PairwiseDataCollatorWithPadding",
+     "get_dataset",
+     "Template",
+     "get_template_and_fix_tokenizer",
+     "templates",
+     "Role",
+     "split_dataset",
+ ]

View File

@@ -6,12 +6,15 @@ from transformers import DataCollatorForSeq2Seq
@dataclass
- class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
+ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """
    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
+         r"""
+         Masks out the input ids except for the responses.
+         """
        padded_labels = []
        for feature, (prompt_len, answer_len) in zip(batch, positions):
            if self.tokenizer.padding_side == "left":
@@ -43,12 +46,6 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
            )
            label_positions.append((prompt_len, answer_len))
-         batch = self.tokenizer.pad(
-             concatenated_features,
-             padding=self.padding,
-             max_length=self.max_length,
-             pad_to_multiple_of=self.pad_to_multiple_of,
-             return_tensors=self.return_tensors,
-         )
+         batch = super().__call__(concatenated_features)
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
        return batch
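
The renamed `PairwiseDataCollatorWithPadding` now lets `DataCollatorForSeq2Seq.__call__` pad the concatenated chosen/rejected features and then masks everything except the responses. A self-contained toy sketch of that masking step (right-padding case only, assuming the repository's usual `IGNORE_INDEX = -100` sentinel):

```python
from typing import List, Tuple

import torch

IGNORE_INDEX = -100  # masking sentinel assumed to match llmtuner.extras.constants

def pad_labels_right(batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
    # keep only the answer span of each sequence; everything else is ignored by the loss
    padded_labels = []
    for feature, (prompt_len, answer_len) in zip(batch, positions):
        labels = torch.full_like(feature, IGNORE_INDEX)
        labels[prompt_len : prompt_len + answer_len] = feature[prompt_len : prompt_len + answer_len]
        padded_labels.append(labels)
    return torch.stack(padded_labels, dim=0)

# prompt of 2 tokens, answer of 2 tokens, right-padded to length 6
print(pad_labels_right(torch.tensor([[11, 12, 13, 14, 0, 0]]), [(2, 2)]))
# -> tensor([[-100, -100,   13,   14, -100, -100]])
```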

View File

@@ -6,6 +6,7 @@ from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset from .aligner import align_dataset
from .parser import get_dataset_list from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func from .preprocess import get_preprocess_and_print_func
@@ -80,7 +81,9 @@ def load_single_dataset(
cache_dir=cache_dir, cache_dir=cache_dir,
token=model_args.ms_hub_token, token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset() )
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
except ImportError: except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`") raise ImportError("Please install modelscope via `pip install modelscope -U`")
else: else:
@@ -117,17 +120,17 @@ def get_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"], stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template) template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache # Load tokenized dataset
if data_args.cache_path is not None: if data_args.tokenized_path is not None:
if os.path.exists(data_args.cache_path): if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.") logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path) dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
if data_args.streaming: if data_args.streaming:
dataset = dataset.to_iterable_dataset() dataset = dataset.to_iterable_dataset()
return dataset return dataset
@@ -138,6 +141,9 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"): with training_args.main_process_first(desc="load dataset"):
all_datasets = [] all_datasets = []
for dataset_attr in get_dataset_list(data_args): for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args) dataset = merge_dataset(all_datasets, data_args, training_args)
@@ -156,10 +162,13 @@ def get_dataset(
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): if data_args.tokenized_path is not None:
if training_args.should_save: if training_args.should_save:
dataset.save_to_disk(data_args.cache_path) dataset.save_to_disk(data_args.tokenized_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path)) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
exit(0)
if training_args.should_log: if training_args.should_log:
try: try:

View File

@@ -53,22 +53,35 @@ class DatasetAttr:
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else [] if data_args.dataset is not None:
try: dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: else:
dataset_info = json.load(f) dataset_names = []
except Exception as err:
if data_args.dataset is not None: if data_args.dataset_dir == "ONLINE":
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None dataset_info = None
else:
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None: if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = [] dataset_list: List[DatasetAttr] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None:
load_from = "ms_hub" if use_modelscope() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

View File

@@ -23,23 +23,25 @@ def preprocess_pretrain_dataset(
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
tokenized_examples = tokenizer(text_examples, add_special_tokens=False) if not data_args.packing:
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} if data_args.template == "gemma":
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) text_examples = [tokenizer.bos_token + example for example in text_examples]
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
total_length = (total_length // block_size) * block_size else:
# split by chunks of cutoff_len tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
result = { concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
k: [t[i : i + block_size] for i in range(0, total_length, block_size)] total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
for k, t in concatenated_examples.items() block_size = data_args.cutoff_len
} total_length = (total_length // block_size) * block_size
if data_args.template == "gemma": result = {
for i in range(len(result["input_ids"])): k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
result["input_ids"][i][0] = tokenizer.bos_token_id for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result return result
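
When packing is enabled, the branch above concatenates every tokenized example into one token stream and slices it into `cutoff_len` blocks, dropping the tail remainder. A stand-alone sketch of that arithmetic with toy numbers:

```python
# Illustrative packing arithmetic; cutoff_len and the token lists are made up.
from itertools import chain

cutoff_len = 8  # stands in for data_args.cutoff_len
tokenized = [[1, 2, 3], [4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14, 15, 16, 17, 18]]

concatenated = list(chain(*tokenized))                            # 18 tokens in total
total_length = (len(concatenated) // cutoff_len) * cutoff_len     # 16: the last 2 tokens are dropped
blocks = [concatenated[i : i + cutoff_len] for i in range(0, total_length, cutoff_len)]
print(blocks)  # two fixed-length blocks of 8 tokens each
```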

View File

@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None, name: Optional[str] = None,
) -> Template: ) -> Template:
if name is None: if name is None:
template = templates["vanilla"] # placeholder template = templates["empty"] # placeholder
else: else:
template = templates.get(name, None) template = templates.get(name, None)
if template is None: if template is None:
@@ -385,7 +385,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=( default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request." "Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
), ),
) )
@@ -414,7 +415,7 @@ _register_template(
_register_template( _register_template(
name="baichuan", name="baichuan",
format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]), format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True, efficient_eos=True,
) )
@@ -441,6 +442,18 @@ _register_template(
) )
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
default_system=(
"You are a helpful AI assistant built by MediaTek Research. "
"The user you are helping speaks Traditional Chinese and comes from Taiwan."
),
efficient_eos=True,
)
_register_template( _register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -514,6 +527,21 @@ _register_template(
) )
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
)
_register_template( _register_template(
name="cpm", name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]), format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -554,6 +582,13 @@ _register_template(
) )
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template( _register_template(
name="falcon", name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -562,6 +597,13 @@ _register_template(
) )
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template( _register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -623,9 +665,28 @@ _register_template(
) )
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
),
default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"],
replace_eos=True,
)
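
Based on the slots registered above, a single-turn conversation under the `llama3` template would render roughly as follows; `BOS` stands in for the tokenizer's `bos_token`, and the concatenation is an approximation of how the formatter assembles the slots:

```python
# Rough rendering of one llama3-formatted prompt (system + user turn).
BOS = "<bos>"  # placeholder for tokenizer.bos_token
prompt = (
    BOS
    + "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
    + "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
    + "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(prompt)  # generation stops once the model emits <|eot_id|>
```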
_register_template( _register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True, force_system=True,
) )
@@ -687,11 +748,6 @@ _register_template(
) )
_register_template(
name="vanilla",
)
_register_template( _register_template(
name="vicuna", name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),

View File

@@ -44,7 +44,7 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
-     max_source_len = max_len - max_target_len
+     max_source_len = max_len - min(max_target_len, target_len)
    return max_source_len, max_target_len
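
The new formula only differs when the target is shorter than its proportional share of `max_len`; in that case the unused target budget is returned to the source instead of being wasted. A short worked example with toy lengths:

```python
# Worked example of the infer_max_len fix (numbers are illustrative).
max_len, reserved_label_len = 1024, 1
source_len, target_len = 100, 20

max_target_len = max(int(max_len * target_len / (source_len + target_len)), reserved_label_len)  # 170
old_source_budget = max_len - max_target_len                   # 854: reserves room the short target never uses
new_source_budget = max_len - min(max_target_len, target_len)  # 1004: the unused budget goes back to the source
print(max_target_len, old_source_budget, new_source_budget)
```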

View File

@@ -14,16 +14,17 @@ from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args from ..hparams import get_eval_args
from ..model import load_model_and_tokenizer from ..model import load_model, load_tokenizer
from .template import get_eval_template from .template import get_eval_template
class Evaluator: class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args) self.tokenizer = load_tokenizer(self.model_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang) self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [ self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES

View File

@@ -1,14 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple from typing import Dict, List, Sequence, Tuple
from ..data import Role from ..data import Role
from ..extras.constants import CHOICES from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass @dataclass
class EvalTemplate: class EvalTemplate:
system: str system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
answer: str answer: str
prefix: str prefix: str
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"] return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example( def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
messages = [] messages = []
for k in range(len(support_set)): for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k]) prompt, response = self._parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt}) messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response}) messages.append({"role": Role.ASSISTANT.value, "content": response})
prompt, response = self.parse_example(target_data) prompt, response = self._parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt}) messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response}) messages.append({"role": Role.ASSISTANT.value, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"] messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages return messages
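
For reference, a rough sketch of the message list `format_example` would build for one support example plus one target under the English template; the question text, the subject, the lowercase role strings, and the `"\nAnswer:"` suffix (not shown in this hunk) are illustrative assumptions:

```python
# Hypothetical 1-shot example; only illustrates the shape of the output.
support_set = [{"question": "1 + 1 = ?", "A": "1", "B": "2", "C": "3", "D": "4", "answer": "B"}]
target_data = {"question": "2 + 3 = ?", "A": "4", "B": "5", "C": "6", "D": "7", "answer": "B"}

messages = [
    {"role": "user", "content": (
        "The following are multiple choice questions (with answers) about elementary mathematics.\n\n"
        "1 + 1 = ?\nA. 1\nB. 2\nC. 3\nD. 4\nAnswer:"
    )},
    {"role": "assistant", "content": "B"},
    {"role": "user", "content": "2 + 3 = ?\nA. 4\nB. 5\nC. 6\nD. 7\nAnswer:"},
    {"role": "assistant", "content": "B"},
]
```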
@@ -39,7 +42,7 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {} eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None: def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix) eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
return eval_template return eval_template
register_eval_template( _register_eval_template(
name="en", name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n", system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
) )
register_eval_template( _register_eval_template(
name="zh", name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\n答案:", answer="\n答案:",
prefix="\n", prefix=" ",
) )

View File

@@ -58,9 +58,17 @@ class LogCallback(TrainerCallback):
self.in_training = True self.in_training = True
self.start_time = time.time() self.start_time = time.time()
self.max_steps = state.max_steps self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.") if args.save_on_each_node:
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
@@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called after logging the last logs. Event called after logging the last logs.
""" """
if not state.is_local_process_zero: if args.save_on_each_node:
return if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,
@@ -122,6 +134,7 @@ class LogCallback(TrainerCallback):
eval_loss=state.log_history[-1].get("eval_loss", None), eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None), predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None), reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None), learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None), epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,

View File

@@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"] PEFT_METHODS = ["lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
@@ -39,9 +41,12 @@ TRAINING_STAGES = {
"Reward Modeling": "rm", "Reward Modeling": "rm",
"PPO": "ppo", "PPO": "ppo",
"DPO": "dpo", "DPO": "dpo",
"ORPO": "orpo",
"Pre-Training": "pt", "Pre-Training": "pt",
} }
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -167,6 +172,19 @@ register_model_group(
) )
register_model_group(
models={
"Breeze-7B": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
},
"Breeze-7B-Chat": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
},
},
template="breeze",
)
register_model_group( register_model_group(
models={ models={
"ChatGLM2-6B-Chat": { "ChatGLM2-6B-Chat": {
@@ -226,6 +244,28 @@ register_model_group(
) )
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group( register_model_group(
models={ models={
"DeepSeek-LLM-7B-Base": { "DeepSeek-LLM-7B-Base": {
@@ -347,6 +387,23 @@ register_model_group(
) )
register_model_group(
models={
"CodeGemma-2B": {
DownloadSource.DEFAULT: "google/codegemma-2b",
},
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
},
template="gemma",
)
register_model_group( register_model_group(
models={ models={
"InternLM-7B": { "InternLM-7B": {
@@ -460,14 +517,41 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Mistral-7B": { "LLaMA3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"LLaMA3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"LLaMA3-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"LLaMA3-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
},
template="llama3",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
}, },
"Mistral-7B-Chat": { "Mistral-7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
}, },
"Mistral-7B-v0.2": {
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
},
"Mistral-7B-v0.2-Chat": { "Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
@@ -479,14 +563,20 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Mixtral-8x7B": { "Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1", DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
}, },
"Mixtral-8x7B-Chat": { "Mixtral-8x7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1", DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
}, },
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
},
}, },
template="mistral", template="mistral",
) )
@@ -656,10 +746,22 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
}, },
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": { "Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
}, },
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-Code-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"Qwen1.5-0.5B-Chat": { "Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -680,10 +782,22 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
}, },
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": { "Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
}, },
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-Code-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": { "Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
@@ -724,6 +838,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
}, },
"Qwen1.5-32B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-int8-Chat": { "Qwen1.5-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
@@ -732,6 +850,14 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
}, },
"Qwen1.5-MoE-A2.7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"Qwen1.5-Code-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
}, },
template="qwen", template="qwen",
) )
@@ -935,18 +1061,3 @@ register_model_group(
}, },
template="zephyr", template="zephyr",
) )
register_model_group(
models={
"Atom-7B": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
},
"Atom-7B-Chat": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
},
},
template="atom",
)

View File

@@ -64,7 +64,7 @@ def check_dependencies() -> None:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2") require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0") require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1") require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
@@ -83,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if param.__class__.__name__ == "Params4bit": if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else: else:
num_bytes = 1 num_bytes = 1
@@ -192,6 +194,13 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32 return torch.float32
def has_tokenized_data(path: os.PathLike) -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def torch_gc() -> None: def torch_gc() -> None:
r""" r"""
Collects GPU memory. Collects GPU memory.
@@ -202,17 +211,15 @@ def torch_gc() -> None:
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def try_download_model_from_ms(model_args: "ModelArguments") -> None: def try_download_model_from_ms(model_args: "ModelArguments") -> str:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path): if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
return return model_args.model_name_or_path
try: try:
from modelscope import snapshot_download from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download( return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
)
except ImportError: except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`") raise ImportError("Please install modelscope via `pip install modelscope -U`")

View File

@@ -25,6 +25,10 @@ def is_galore_available():
return _is_package_available("galore_torch") return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available(): def is_jieba_available():
return _is_package_available("jieba") return _is_package_available("jieba")
@@ -49,10 +53,6 @@ def is_starlette_available():
return _is_package_available("sse_starlette") return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")

View File

@@ -193,6 +193,6 @@ def llama_flash_attn_forward(
def apply_llama_patch() -> None:
-   require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
+   require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3")
    LlamaAttention.forward = llama_torch_attn_forward
    LlamaFlashAttention2.forward = llama_flash_attn_forward

View File

@@ -1,38 +0,0 @@
import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
def patch_mixtral_replace_moe_impl() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -52,6 +52,6 @@ def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
-       figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_")))
+       figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
        plt.savefig(figure_path, format="png", dpi=100)
        print("Figure saved at:", figure_path)

View File

@@ -84,9 +84,9 @@ class DataArguments:
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
}, },
) )
cache_path: Optional[str] = field( tokenized_path: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to save or load the pre-processed datasets."}, metadata={"help": "Path to save or load the tokenized datasets."},
) )
def __post_init__(self): def __post_init__(self):
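A minimal sketch of how the renamed tokenized_path pairs with the new has_tokenized_data helper: reuse the cached tokenized dataset when the directory is non-empty, otherwise tokenize and save it. save_to_disk/load_from_disk are standard datasets APIs; preprocess_func and the surrounding control flow are illustrative assumptions, not the project's exact loader code:

    from datasets import load_from_disk

    if data_args.tokenized_path is not None and has_tokenized_data(data_args.tokenized_path):
        dataset = load_from_disk(data_args.tokenized_path)  # reuse the cached tokenized dataset
    else:
        dataset = dataset.map(preprocess_func, batched=True)  # hypothetical tokenization step
        if data_args.tokenized_path is not None:
            dataset.save_to_disk(data_args.tokenized_path)  # cache it for the next run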

View File

@@ -102,10 +102,18 @@ class RLHFArguments:
default="sigmoid", default="sigmoid",
metadata={"help": "The type of DPO loss to use."}, metadata={"help": "The type of DPO loss to use."},
) )
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
dpo_ftx: float = field( dpo_ftx: float = field(
default=0.0, default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
) )
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field( ppo_buffer_size: int = field(
default=1, default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
@@ -114,10 +122,6 @@ class RLHFArguments:
        default=4,
        metadata={"help": "The number of epochs to perform in a PPO optimization step."},
    )
    ppo_logger: Optional[str] = field(
        default=None,
        metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
    )
    ppo_score_norm: bool = field(
        default=False,
        metadata={"help": "Use score normalization in PPO training."},
@@ -168,7 +172,7 @@ class GaloreArguments:
    use_galore: bool = field(
        default=False,
-       metadata={"help": "Whether or not to use gradient low-Rank projection."},
+       metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
    )
    galore_target: str = field(
        default="all",
@@ -200,7 +204,54 @@ class GaloreArguments:
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_update_ratio: float = field(
default=0.0,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": """The mode of the mask for BAdam optimizer. \
`adjacent` means that the trainable parameters are adjacent to each other, \
`scatter` means that trainable parameters are randomly choosed from the weight."""
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": """The verbosity level of BAdam optimizer. \
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
},
)
@dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
    r"""
    Arguments pertaining to which techniques we are going to fine-tuning with.
    """
@@ -209,7 +260,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
        default=False,
        metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
    )
-   stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
+   stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )
@@ -248,12 +299,18 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.") raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full": if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.") raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
if self.use_galore and self.finetuning_type == "lora": if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.") raise ValueError("Cannot use LoRA with GaLore together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`.""" r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
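As a hedged example of how the new optimizer switches are meant to be driven, the dictionary below only uses field names defined in the dataclasses above; the values are illustrative, not recommended settings:

    finetuning_config = dict(
        stage="sft",
        finetuning_type="full",
        use_badam=True,               # enable the block-wise BAdam optimizer
        badam_mode="layer",           # update one block of layers at a time
        badam_switch_block_every=50,  # rotate to the next block every 50 steps
        badam_switch_mode="ascending",
    )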

View File

@@ -53,22 +53,34 @@ class ModelArguments:
        default=True,
        metadata={"help": "Whether or not to use double quantization in int4 training."},
    )
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
        default=None,
        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
    )
    flash_attn: bool = field(
        default=False,
-       metadata={"help": "Enable FlashAttention-2 for faster training."},
+       metadata={"help": "Enable FlashAttention for faster training."},
    )
    shift_attn: bool = field(
        default=False,
        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
    )
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
    use_unsloth: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
    )
    moe_aux_loss_coef: Optional[float] = field(
        default=None,
        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
    )
    disable_gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable gradient checkpointing."},
    )
@@ -121,6 +133,10 @@ class ModelArguments:
        default=1,
        metadata={"help": "The file shard size (in GB) of the exported model."},
    )
    export_device: str = field(
        default="cpu",
        metadata={"help": "The device used in model export."},
    )
    export_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the exported model."},
    )

View File

@@ -11,8 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies
-from ..extras.packages import is_unsloth_available
+from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -75,6 +74,32 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
        raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    parser = HfArgumentParser(_TRAIN_ARGS)
    return _parse_args(parser, args)
@@ -119,21 +144,24 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
        raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
    if training_args.max_steps == -1 and data_args.streaming:
        raise ValueError("Please specify `max_steps` in streaming mode.")
    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")
-   if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
-       raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
+   if training_args.do_train and model_args.quantization_device_map == "auto":
+       raise ValueError("Cannot use device map for quantized models in training.")
-   if finetuning_args.use_dora:
-       if model_args.quantization_bit is not None:
-           require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
-       if model_args.use_unsloth:
-           raise ValueError("Unsloth does not support DoRA.")
+   if finetuning_args.use_dora and model_args.use_unsloth:
+       raise ValueError("Unsloth does not support DoRA.")
    if finetuning_args.pure_bf16:
        if not is_torch_bf16_gpu_available():
@@ -149,13 +177,21 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    ):
        raise ValueError("Distributed training does not support layer-wise GaLore.")
-   if finetuning_args.use_galore and training_args.deepspeed is not None:
-       raise ValueError("GaLore is incompatible with DeepSpeed.")
+   if (
+       finetuning_args.use_badam
+       and finetuning_args.badam_mode == "layer"
+       and training_args.parallel_mode.value == "distributed"
+   ):
+       raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+   if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
+       raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
    _verify_model_args(model_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)
    if (
        training_args.do_train
@@ -233,6 +269,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    elif training_args.fp16:
        model_args.compute_dtype = torch.float16
    model_args.device_map = {"": get_current_device()}
    model_args.model_max_length = data_args.cutoff_len
    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
@@ -274,8 +311,12 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
raise ValueError("vLLM engine does not support RoPE scaling.") raise ValueError("vLLM engine does not support RoPE scaling.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto" if model_args.export_dir is not None:
model_args.device_map = {"": torch.device(model_args.export_device)}
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args return model_args, data_args, finetuning_args, generating_args
@@ -292,6 +333,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto" model_args.device_map = "auto"

View File

@@ -1,10 +1,9 @@
-from .loader import load_model, load_model_and_tokenizer, load_tokenizer
+from .loader import load_model, load_tokenizer
from .utils import find_all_linear_modules, load_valuehead_params
__all__ = [
    "load_model",
-   "load_model_and_tokenizer",
    "load_tokenizer",
    "load_valuehead_params",
    "find_all_linear_modules",

View File

@@ -32,9 +32,12 @@ def init_adapter(
logger.info("Adapter is not found at evaluation, load the base model.") logger.info("Adapter is not found at evaluation, load the base model.")
return model return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if finetuning_args.finetuning_type == "full" and is_trainable: if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16: if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float() model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable: if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -66,6 +69,8 @@ def init_adapter(
        for name, _ in model.named_modules():
            if ".0." in name:
                freeze_modules.add(name.split(".0.")[-1].split(".")[0])
            elif ".1." in name:  # MoD starts from layer 1
                freeze_modules.add(name.split(".1.")[-1].split(".")[0])
        trainable_layers = []
        for module_name in finetuning_args.name_module_trainable:
@@ -79,7 +84,7 @@ def init_adapter(
        for name, param in model.named_parameters():
            if any(trainable_layer in name for trainable_layer in trainable_layers):
-               if not finetuning_args.pure_bf16:
+               if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                    param.data = param.data.to(torch.float32)
            else:
                param.requires_grad_(False)
@@ -129,9 +134,12 @@ def init_adapter(
        if finetuning_args.use_llama_pro:
            target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
-       if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
-           if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
-               raise ValueError("DoRA is not compatible with PTQ-quantized models.")
+       if (
+           finetuning_args.use_dora
+           and getattr(model, "quantization_method", None) is not None
+           and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+       ):
+           raise ValueError("DoRA is not compatible with PTQ-quantized models.")
        peft_kwargs = {
            "r": finetuning_args.lora_rank,
@@ -139,24 +147,28 @@ def init_adapter(
"lora_alpha": finetuning_args.lora_alpha, "lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout, "lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora, "use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
} }
if model_args.use_unsloth: if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
else: else:
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
inference_mode=False, inference_mode=False,
modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora, use_dora=finetuning_args.use_dora,
**peft_kwargs, **peft_kwargs,
) )
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
if not finetuning_args.pure_bf16: if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()): for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
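For readers less familiar with peft, a condensed sketch of what the non-unsloth branch above effectively builds; the rank, alpha, and target module values are placeholders, while the keyword names mirror the diff:

    from peft import LoraConfig, TaskType, get_peft_model

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,                                  # finetuning_args.lora_rank
        lora_alpha=16,                        # finetuning_args.lora_alpha
        lora_dropout=0.0,                     # finetuning_args.lora_dropout
        use_rslora=False,
        use_dora=False,
        target_modules=["q_proj", "v_proj"],  # or find_all_linear_modules(model)
        modules_to_save=None,                 # finetuning_args.additional_target
    )
    model = get_peft_model(model, lora_config)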

View File

@@ -1,8 +1,9 @@
-from typing import TYPE_CHECKING, Any, Dict, Tuple
+from typing import TYPE_CHECKING, Any, Dict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.constants import MOD_SUPPORTED_MODELS
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
@@ -20,6 +21,7 @@ logger = get_logger(__name__)
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
    model_args.model_name_or_path = try_download_model_from_ms(model_args)
    return {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
@@ -34,16 +36,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
    Note: including inplace operation of model_args.
    """
-   try_download_model_from_ms(model_args)
    init_kwargs = _get_init_kwargs(model_args)
-   tokenizer = AutoTokenizer.from_pretrained(
-       model_args.model_name_or_path,
-       use_fast=model_args.use_fast_tokenizer,
-       split_special_tokens=model_args.split_special_tokens,
-       padding_side="right",
-       **init_kwargs,
-   )
+   try:
+       tokenizer = AutoTokenizer.from_pretrained(
+           model_args.model_name_or_path,
+           use_fast=model_args.use_fast_tokenizer,
+           split_special_tokens=model_args.split_special_tokens,
+           padding_side="right",
+           **init_kwargs,
+       )
+   except ValueError:  # try the fast one
+       tokenizer = AutoTokenizer.from_pretrained(
+           model_args.model_name_or_path,
+           use_fast=True,
+           padding_side="right",
+           **init_kwargs,
+       )
    patch_tokenizer(tokenizer)
    return tokenizer
@@ -74,6 +83,8 @@ def load_model(
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
"device_map": {"": get_current_device()}, "device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None), "rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
} }
try: try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
@@ -86,7 +97,24 @@ def load_model(
logger.warning("Unsloth does not support loading adapters.") logger.warning("Unsloth does not support loading adapters.")
if model is None: if model is None:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs) init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
from MoD import AutoMoDModelForCausalLM
model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
patch_model(model, tokenizer, model_args, is_trainable) patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer) register_autoclass(config, model, tokenizer)
@@ -94,7 +122,7 @@ def load_model(
    model = init_adapter(model, model_args, finetuning_args, is_trainable)
    if add_valuehead:
-       model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+       model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_valuehead_model(model)
        if model_args.adapter_name_or_path is not None:
@@ -110,9 +138,6 @@ def load_model(
    if not is_trainable:
        model.requires_grad_(False)
        model.eval()
-       for param in model.parameters():
-           if param.device.type == "cuda":
-               param.data = param.data.to(model_args.compute_dtype)
    else:
        model.train()
@@ -134,17 +159,3 @@ def load_model(
    )
    return model
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
"""
tokenizer = load_tokenizer(model_args)
model = load_model(tokenizer, model_args, finetuning_args, is_trainable, add_valuehead)
return model, tokenizer
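With load_model_and_tokenizer gone, callers are expected to make the two calls themselves; the equivalent sketch, taken directly from the removed wrapper:

    tokenizer = load_tokenizer(model_args)
    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=False)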

View File

@@ -17,8 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
-from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
-from .utils import QuantizationMethod
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
if TYPE_CHECKING:
@@ -32,47 +31,6 @@ logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
    r"""
    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
@@ -103,9 +61,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
    return samples
-def _configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
-) -> None:
+def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    if model_args.flash_attn:
        if not is_flash_attn2_available():
            logger.warning("FlashAttention2 is not installed.")
@@ -115,9 +71,9 @@ def _configure_attn_implementation(
        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
            setattr(config, "attn_implementation", "flash_attention_2")
        else:
-           init_kwargs["attn_implementation"] = "flash_attention_2"
+           setattr(config, "_attn_implementation", "flash_attention_2")
    else:
-       init_kwargs["attn_implementation"] = "eager"
+       setattr(config, "_attn_implementation", "eager")
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -173,15 +129,22 @@ def _configure_quantization(
""" """
if getattr(config, "quantization_config", None): # ptq if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "") quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ: if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM: if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
@@ -208,11 +171,6 @@ def _configure_quantization(
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
if model_args.quantization_bit == 8: if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
@@ -227,10 +185,66 @@ def _configure_quantization(
                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
            )
-       init_kwargs["device_map"] = {"": get_current_device()}
+       if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
+           if model_args.quantization_bit != 4:
+               raise ValueError("Only 4-bit quantized model can use auto device map.")
+           require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
+           require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
+           require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+       else:
+           init_kwargs["device_map"] = {"": get_current_device()}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def _prepare_model_for_training(
    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
@@ -253,20 +267,16 @@ def _prepare_model_for_training(
        else:
            # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
            model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-           model.enable_input_require_grads()
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")
    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
-       def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-           return output.to(torch.float32)
        logger.info("Upcasting lm_head outputs in float32.")
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-           output_layer.register_forward_hook(fp32_forward_post_hook)
+           output_layer.register_forward_hook(_fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -284,12 +294,7 @@ def patch_config(
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-   if getattr(config, "model_type", None) == "qwen":
-       setattr(config, "use_flash_attn", model_args.flash_attn)
-       for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-           setattr(config, dtype_name, model_args.compute_dtype == dtype)
-   _configure_attn_implementation(config, model_args, init_kwargs)
+   _configure_attn_implementation(config, model_args)
    _configure_rope(config, model_args, is_trainable)
    _configure_longlora(config, model_args, is_trainable)
    _configure_quantization(config, tokenizer, model_args, init_kwargs)
@@ -298,12 +303,29 @@ def patch_config(
        setattr(config, "use_cache", True)
        logger.info("Using KV cache for faster generation.")
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
setattr(config, "output_router_logits", True)
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]: if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs: # quant models cannot use auto device map if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()} init_kwargs["device_map"] = model_args.device_map
if init_kwargs["device_map"] == "auto": if init_kwargs["device_map"] == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder init_kwargs["offload_folder"] = model_args.offload_folder
@@ -312,10 +334,18 @@ def patch_config(
def patch_model(
    model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
) -> None:
gen_config = model.generation_config # check and fix generation config
if not gen_config.do_sample and (
(gen_config.temperature is not None and gen_config.temperature != 1.0)
or (gen_config.top_p is not None and gen_config.top_p != 1.0)
or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
):
gen_config.do_sample = True
if "GenerationMixin" not in str(model.generate.__func__): if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model) model.generate = MethodType(PreTrainedModel.generate, model)
if getattr(model.config, "model_type", None) == "chatglm": if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer) setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
@@ -325,15 +355,15 @@ def patch_model(
    if is_trainable:
        _prepare_model_for_training(model, model_args)
-   if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
-       require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
-       from deepspeed.utils import set_z3_leaf_modules  # type: ignore
+   if getattr(model.config, "model_type", None) == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-       set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
-       if is_trainable:
-           patch_mixtral_replace_moe_impl()
+       add_z3_leaf_module(model, MixtralSparseMoeBlock)
+   if getattr(model.config, "model_type", None) == "qwen2moe":
+       from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+       add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
    try:
        model.add_model_tags(["llama-factory"])
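As a standalone illustration of the lm_head upcasting pattern used above — a forward hook that returns a value replaces the module's output — with an assumed toy layer instead of a real language model head:

    import torch

    def fp32_post_hook(module: torch.nn.Module, args, output: torch.Tensor) -> torch.Tensor:
        return output.to(torch.float32)  # upcast half-precision logits before the loss

    lm_head = torch.nn.Linear(16, 32).to(torch.bfloat16)
    lm_head.register_forward_hook(fp32_post_hook)
    logits = lm_head(torch.randn(2, 16, dtype=torch.bfloat16))
    assert logits.dtype == torch.float32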

View File

@@ -1,9 +1,13 @@
import inspect
from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import cached_file
from transformers.utils.versions import require_version
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
@@ -28,11 +32,23 @@ class QuantizationMethod(str, Enum):
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
QUANTO = "quanto"
def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, [module])
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
    r"""
-   Finds all available modules to apply lora.
+   Finds all available modules to apply lora or galore.
    """
    quantization_method = getattr(model, "quantization_method", None)
    if quantization_method is None:
@@ -47,6 +63,8 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
    output_layer_names = ["lm_head"]
    if model.config.model_type == "chatglm":
        output_layer_names.append("output_layer")
    elif model.config.model_type == "internlm2":
        output_layer_names.append("output")
    module_names = set()
    for name, module in model.named_modules():
@@ -84,6 +102,42 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
    return module_names
def gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.
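A toy sketch of the failure mode the custom checkpointing function guards against: when the blocks upstream of a checkpointed segment are frozen (as under block-wise BAdam), the segment's inputs do not require grad and reentrant checkpointing produces no gradients, so the tensor arguments are flipped to requires_grad first. This is a self-contained illustration, not the project's code:

    import torch
    from torch.utils.checkpoint import checkpoint

    frozen = torch.nn.Linear(8, 8).requires_grad_(False)  # a block not being updated right now
    trainable = torch.nn.Linear(8, 8)                      # the block currently being optimized

    x = frozen(torch.randn(2, 8))       # x.requires_grad is False here
    if torch.is_floating_point(x) and not x.requires_grad:
        x.requires_grad_(True)          # what custom_gradient_checkpointing_func does per tensor arg
    loss = checkpoint(trainable, x, use_reentrant=True).sum()
    loss.backward()                     # gradients now reach trainable.weight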

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
@@ -8,7 +9,7 @@ from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -20,12 +21,9 @@ if TYPE_CHECKING:
class CustomDPOTrainer(DPOTrainer):
    def __init__(
        self,
-       beta: float,
-       loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
-       ftx_gamma: float,
        model: Union["PreTrainedModel", torch.nn.Module],
-       ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
+       ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: bool = True,
        **kwargs,
    ):
@@ -47,10 +45,10 @@ class CustomDPOTrainer(DPOTrainer):
        self._peft_has_been_casted_to_bf16 = False
        self.ref_model = ref_model
-       self.beta = beta
-       self.label_smoothing = 0
-       self.loss_type = loss_type
-       self.ftx_gamma = ftx_gamma
+       self.beta = finetuning_args.dpo_beta
+       self.label_smoothing = finetuning_args.dpo_label_smoothing
+       self.loss_type = finetuning_args.dpo_loss
+       self.ftx_gamma = finetuning_args.dpo_ftx
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        Trainer.__init__(self, model=model, **kwargs)
@@ -66,14 +64,23 @@ class CustomDPOTrainer(DPOTrainer):
        else:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+       if finetuning_args.use_badam:
+           from badam import clip_grad_norm_for_sparse_tensor
+           self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-   def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-       self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
-       if self.optimizer is None:
-           self.create_optimizer()
-       self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+   def create_optimizer(self) -> "torch.optim.Optimizer":
+       if self.optimizer is None:
+           self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+       return super().create_optimizer()
+   def create_scheduler(
+       self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+   ) -> "torch.optim.lr_scheduler.LRScheduler":
+       create_custom_scheduler(self.args, num_training_steps, optimizer)
+       return super().create_scheduler(num_training_steps, optimizer)
-   def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
+   def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
        r"""
        Computes supervised cross-entropy loss of given labels under the given logits.
@@ -84,18 +91,27 @@ class CustomDPOTrainer(DPOTrainer):
return -all_logps return -all_logps
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under the given logits if loss_type != IPO.
Otherwise the average log probabilities.
"""
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model( all_logits: "torch.Tensor" = model(
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32) ).logits.to(torch.float32)
all_logps = self.get_batch_logps( all_logps = self.get_batch_logps(
all_logits, logits=all_logits,
batch["labels"], labels=batch_copied["labels"],
average_log_prob=False, average_log_prob=(self.loss_type == "ipo"),
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id, label_pad_token_id=self.label_pad_token_id,
) )
batch_size = batch["input_ids"].size(0) // 2 batch_size = batch["input_ids"].size(0) // 2
@@ -106,9 +122,9 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, torch.Tensor], batch: Dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train", train_eval: Literal["train", "eval"] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r""" r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test. Computes the DPO loss and other metrics for the given batch of inputs for train or test.
""" """
@@ -149,13 +165,13 @@ class CustomDPOTrainer(DPOTrainer):
reward_accuracies = (chosen_rewards > rejected_rewards).float() reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else "" prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
return losses.mean(), metrics return losses.mean(), metrics
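Note on the refactor above: the single create_optimizer_and_scheduler override is replaced by separate create_optimizer / create_scheduler hooks, which works because Hugging Face Trainer only builds its default optimizer and scheduler when they are still unset. A minimal sketch of that pattern (MyTrainer and the placeholder comment are illustrative, not repository code):

    from typing import Optional
    import torch
    from transformers import Trainer

    class MyTrainer(Trainer):  # hypothetical subclass, for illustration only
        def create_optimizer(self) -> "torch.optim.Optimizer":
            if self.optimizer is None:
                # try to build a custom optimizer first; if it stays None,
                # super().create_optimizer() falls back to the stock AdamW
                self.optimizer = None
            return super().create_optimizer()

        def create_scheduler(self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None):
            # attach any custom per-parameter schedulers here, then defer to the default
            return super().create_scheduler(num_training_steps, optimizer)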

View File

@@ -2,13 +2,12 @@
 from typing import TYPE_CHECKING, List, Optional
-from ...data import get_dataset, split_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
 from ..utils import create_modelcard_and_push, create_ref_model
-from .collator import DPODataCollatorWithPadding
 from .trainer import CustomDPOTrainer
@@ -28,7 +27,8 @@ def run_dpo(
     tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
-    data_collator = DPODataCollatorWithPadding(
+    data_collator = PairwiseDataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
@@ -45,13 +45,10 @@ def run_dpo(
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
-        beta=finetuning_args.dpo_beta,
-        loss_type=finetuning_args.dpo_loss,
-        ftx_gamma=finetuning_args.dpo_ftx,
-        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
@@ -66,7 +63,7 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
     # Evaluation
     if training_args.do_eval:

View File

@@ -0,0 +1,4 @@
from .workflow import run_orpo
__all__ = ["run_orpo"]

View File

@@ -0,0 +1,127 @@
from collections import defaultdict
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model

from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import FinetuningArguments


class CustomORPOTrainer(DPOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", "torch.nn.Module"],
        finetuning_args: "FinetuningArguments",
        disable_dropout: bool = True,
        **kwargs,
    ):
        if disable_dropout:
            disable_dropout_in_model(model)

        self.finetuning_args = finetuning_args
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False
        self.beta = finetuning_args.orpo_beta
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        Trainer.__init__(self, model=model, **kwargs)

        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""
        Computes ORPO's odds ratio (OR) loss.
        """
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        odds_ratio_loss = -F.logsigmoid(log_odds)
        return odds_ratio_loss

    def concatenated_forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""
        Computes the average log probabilities of the labels under the given logits.
        """
        all_logits: "torch.Tensor" = model(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
        ).logits.to(torch.float32)

        all_logps = self.get_batch_logps(
            logits=all_logits,
            labels=batch["labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits

    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: Dict[str, "torch.Tensor"],
        train_eval: Literal["train", "eval"] = "train",
    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
        r"""
        Computes the ORPO loss and other metrics for the given batch of inputs for train or test.
        """
        metrics = {}
        chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
        sft_loss = -chosen_logps
        odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
        batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()

        chosen_rewards = self.beta * chosen_logps.detach()
        rejected_rewards = self.beta * rejected_logps.detach()
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
        metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
        metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
        metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
        metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
        metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
        metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
        return batch_loss, metrics
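For reference, odds_ratio_loss and get_batch_loss_metrics above implement the ORPO objective. In the notation of the code (where chosen_logps and rejected_logps are average per-token log-probabilities, so p_w = exp(chosen_logps) and p_l = exp(rejected_logps), and beta = finetuning_args.orpo_beta), a sketch of the math:

    \log\operatorname{odds}(y\mid x) = \log\frac{p(y\mid x)}{1 - p(y\mid x)}, \qquad
    \mathcal{L}_{\mathrm{OR}} = -\log\sigma\big(\log\operatorname{odds}(y_w\mid x) - \log\operatorname{odds}(y_l\mid x)\big), \qquad
    \mathcal{L}_{\mathrm{ORPO}} = \mathbb{E}\big[-\log p(y_w\mid x) + \beta\,\mathcal{L}_{\mathrm{OR}}\big]

so torch.log1p(-torch.exp(chosen_logps)) computes log(1 - p_w), and sft_loss = -chosen_logps is the supervised NLL term that the odds-ratio penalty is added to.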

View File

@@ -0,0 +1,68 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

from typing import TYPE_CHECKING, List, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from .trainer import CustomORPOTrainer

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments


def run_orpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    tokenizer = load_tokenizer(model_args)
    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    data_collator = PairwiseDataCollatorWithPadding(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
    )

    # Update arguments
    training_args.remove_unused_columns = False  # important for pairwise dataset

    # Initialize our Trainer
    trainer = CustomORPOTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **split_dataset(dataset, data_args, training_args),
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies", "sft_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -1,25 +1,29 @@
 import math
 import os
 import sys
+from types import MethodType
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 import torch
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
+from transformers.optimization import get_scheduler
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
-from trl import PPOTrainer
+from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from ..utils import create_custom_optimzer, create_custom_scheduler
 from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from datasets import Dataset
+    from transformers import DataCollatorWithPadding, PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
     from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
@@ -40,10 +44,53 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
-        reward_model: "AutoModelForCausalLMWithValueHead",
-        **kwargs,
+        model: "AutoModelForCausalLMWithValueHead",
+        reward_model: Optional["AutoModelForCausalLMWithValueHead"],
+        ref_model: Optional["AutoModelForCausalLMWithValueHead"],
+        tokenizer: "PreTrainedTokenizer",
+        dataset: "Dataset",
+        data_collator: "DataCollatorWithPadding",
     ):
-        PPOTrainer.__init__(self, **kwargs)
+        backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
+        ppo_config = PPOConfig(
+            model_name=model_args.model_name_or_path,
+            learning_rate=training_args.learning_rate,
+            mini_batch_size=training_args.per_device_train_batch_size,
+            batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
+            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+            ppo_epochs=finetuning_args.ppo_epochs,
+            max_grad_norm=training_args.max_grad_norm,
+            seed=training_args.seed,
+            optimize_device_cache=True,
+            target=finetuning_args.ppo_target,
+            use_score_scaling=finetuning_args.ppo_score_norm,
+            use_score_norm=finetuning_args.ppo_score_norm,
+            whiten_rewards=finetuning_args.ppo_whiten_rewards,
+            accelerator_kwargs={"step_scheduler_with_optimizer": False},
+            log_with=training_args.report_to[0] if training_args.report_to else None,
+            project_kwargs={"logging_dir": training_args.logging_dir},
+        )
+        # Create optimizer and scheduler
+        if training_args.max_steps > 0:
+            num_training_steps = training_args.max_steps
+        else:
+            total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
+            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+        optimizer = self.create_optimizer(model, training_args, finetuning_args)
+        scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
+        PPOTrainer.__init__(
+            self,
+            config=ppo_config,
+            model=model,
+            ref_model=ref_model,
+            tokenizer=tokenizer,
+            dataset=dataset,
+            data_collator=data_collator,
+            lr_scheduler=scheduler,
+        )
         self.args = training_args
         self.model_args = model_args
@@ -78,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             else:
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -205,6 +257,44 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
         )
+    def create_optimizer(
+        self,
+        model: "AutoModelForCausalLMWithValueHead",
+        training_args: "Seq2SeqTrainingArguments",
+        finetuning_args: "FinetuningArguments",
+    ) -> "torch.optim.Optimizer":
+        optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+        if optimizer is None:
+            decay_params, nodecay_params = [], []
+            decay_param_names = self.get_decay_parameter_names(model)
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    if name in decay_param_names:
+                        decay_params.append(param)
+                    else:
+                        nodecay_params.append(param)
+            optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+            param_groups = [
+                dict(params=nodecay_params),
+                dict(params=decay_params, weight_decay=training_args.weight_decay),
+            ]
+            optimizer = optim_class(param_groups, **optim_kwargs)
+        return optimizer
+    def create_scheduler(
+        self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(training_args, num_training_steps, optimizer)
+        lr_scheduler = get_scheduler(
+            training_args.lr_scheduler_type,
+            optimizer=optimizer,
+            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+            num_training_steps=num_training_steps,
+        )
+        return lr_scheduler
     @torch.no_grad()
     def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
@@ -269,7 +359,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch = self.prepare_model_inputs(queries, responses)
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
             values = torch.transpose(values, 0, 1)
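A quick worked example of the batch-size arithmetic in the constructor above (the concrete numbers are illustrative, not repository defaults):

    # illustrative values
    per_device_train_batch_size = 2
    gradient_accumulation_steps = 4
    ppo_buffer_size = 2

    backward_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 8 samples per optimizer step
    rollout_batch_size = backward_batch_size * ppo_buffer_size                       # 16 samples collected per PPO step
    mini_batch_size = per_device_train_batch_size                                    # 2 samples per forward/backward pass
    assert backward_batch_size == 8 and rollout_batch_size == 16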

View File

@@ -1,19 +1,15 @@
 # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
-import math
 from typing import TYPE_CHECKING, List, Optional
-from torch.optim import AdamW
 from transformers import DataCollatorWithPadding
-from transformers.optimization import get_scheduler
-from trl import PPOConfig
 from ...data import get_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+from ..utils import create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer
@@ -42,45 +38,6 @@ def run_ppo(
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
     reward_model = create_reward_model(model, model_args, finetuning_args)
-    # Create ppo config
-    backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
-    ppo_config = PPOConfig(
-        model_name=model_args.model_name_or_path,
-        learning_rate=training_args.learning_rate,
-        mini_batch_size=training_args.per_device_train_batch_size,
-        batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
-        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
-        ppo_epochs=finetuning_args.ppo_epochs,
-        max_grad_norm=training_args.max_grad_norm,
-        seed=training_args.seed,
-        optimize_device_cache=True,
-        target=finetuning_args.ppo_target,
-        log_with=finetuning_args.ppo_logger,
-        use_score_scaling=finetuning_args.ppo_score_norm,
-        use_score_norm=finetuning_args.ppo_score_norm,
-        whiten_rewards=finetuning_args.ppo_whiten_rewards,
-        accelerator_kwargs={"step_scheduler_with_optimizer": False},
-        project_kwargs={"logging_dir": training_args.logging_dir},
-    )
-    # Create optimizer and scheduler
-    if training_args.max_steps > 0:
-        num_training_steps = training_args.max_steps
-    else:
-        total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
-        num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
-    if optimizer is None:
-        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
-    lr_scheduler = get_scheduler(
-        training_args.lr_scheduler_type,
-        optimizer=optimizer,
-        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
-        num_training_steps=num_training_steps,
-    )
     # Initialize our Trainer
     ppo_trainer = CustomPPOTrainer(
         model_args=model_args,
@@ -88,15 +45,12 @@ def run_ppo(
         finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks + [FixValueHeadModelCallback()],
-        reward_model=reward_model,
-        config=ppo_config,
         model=model,
+        reward_model=reward_model,
         ref_model=ref_model,
         tokenizer=tokenizer,
         dataset=dataset,
         data_collator=data_collator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
     )
     # Training

View File

@@ -1,12 +1,15 @@
-from typing import TYPE_CHECKING
+from types import MethodType
+from typing import TYPE_CHECKING, Optional
 from transformers import Trainer
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
+    import torch
     from ...hparams import FinetuningArguments
@@ -21,10 +24,18 @@ class CustomTrainer(Trainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
-        if self.optimizer is None:
-            self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_optimizer(self) -> "torch.optim.Optimizer":
+        if self.optimizer is None:
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
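The same three-line BAdam hook recurs in every trainer in this changeset; it relies on types.MethodType to rebind a plain function as a bound method of one existing object (here the accelerator), leaving the class untouched. A minimal standalone sketch of that mechanism (toy classes, not repository code):

    from types import MethodType

    class Accelerator:  # toy stand-in for accelerate.Accelerator
        def clip_grad_norm_(self, parameters, max_norm):
            return "dense clipping"

    def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm):
        # replacement that would handle sparse gradients; `self` is the patched instance
        return "sparse-aware clipping"

    acc = Accelerator()
    acc.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, acc)  # patch only this instance
    print(acc.clip_grad_norm_([], 1.0))  # -> "sparse-aware clipping"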

View File

@@ -1,29 +0,0 @@
from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorWithPadding


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        features = [
            {
                "input_ids": feature["prompt_ids"] + feature[key],
                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
            }
            for key in ("chosen_ids", "rejected_ids")
            for feature in features
        ]
        return super().__call__(features)
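The collator removed above (now provided as PairwiseDataCollatorWithPadding in the shared data package) flattens each pairwise feature into two examples, all chosen examples first and all rejected examples last. A toy trace of that list comprehension with made-up token ids:

    features = [
        {"prompt_ids": [1, 2], "chosen_ids": [3], "rejected_ids": [4, 5]},
        {"prompt_ids": [6], "chosen_ids": [7, 8], "rejected_ids": [9]},
    ]
    flattened = [
        {"input_ids": f["prompt_ids"] + f[key], "attention_mask": [1] * (len(f["prompt_ids"]) + len(f[key]))}
        for key in ("chosen_ids", "rejected_ids")
        for f in features
    ]
    # first n = chosen: [1, 2, 3] and [6, 7, 8]; last n = rejected: [1, 2, 4, 5] and [6, 9]
    assert [x["input_ids"] for x in flattened] == [[1, 2, 3], [6, 7, 8], [1, 2, 4, 5], [6, 9]]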

View File

@@ -1,12 +1,13 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from types import MethodType
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
@@ -28,13 +29,21 @@ class PairwiseTrainer(Trainer):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
-        if self.optimizer is None:
-            self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_optimizer(self) -> "torch.optim.Optimizer":
+        if self.optimizer is None:
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False

View File

@@ -2,13 +2,12 @@
 from typing import TYPE_CHECKING, List, Optional
-from ...data import get_dataset, split_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..utils import create_modelcard_and_push
-from .collator import PairwiseDataCollatorWithPadding
 from .metric import compute_accuracy
 from .trainer import PairwiseTrainer
@@ -56,7 +55,7 @@ def run_rm(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
     # Evaluation
     if training_args.do_eval:

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 import numpy as np
-from transformers.utils.versions import require_version
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
@@ -33,10 +32,6 @@ class ComputeMetrics:
         r"""
         Uses the model predictions to compute metrics.
         """
-        require_version("jieba", "To fix: pip install jieba")
-        require_version("nltk", "To fix: pip install nltk")
-        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
         preds, labels = eval_preds
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

View File

@@ -1,5 +1,6 @@
 import json
 import os
+from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
@@ -8,7 +9,7 @@ from transformers import Seq2SeqTrainer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
@@ -28,13 +29,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
-        if self.optimizer is None:
-            self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_optimizer(self) -> "torch.optim.Optimizer":
+        if self.optimizer is None:
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
     def prediction_step(
         self,

View File

@@ -7,8 +7,9 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras.callbacks import LogCallback
 from ..extras.logging import get_logger
 from ..hparams import get_infer_args, get_train_args
-from ..model import load_model_and_tokenizer
+from ..model import load_model, load_tokenizer
 from .dpo import run_dpo
+from .orpo import run_orpo
 from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
@@ -36,6 +37,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
         run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
     elif finetuning_args.stage == "dpo":
         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif finetuning_args.stage == "orpo":
+        run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
         raise ValueError("Unknown task.")
@@ -49,8 +52,9 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
         raise ValueError("Please merge adapters before quantizing the model.")
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    tokenizer = load_tokenizer(model_args)
     get_template_and_fix_tokenizer(tokenizer, data_args.template)
+    model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab
     if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
         raise ValueError("Cannot merge adapters to a quantized model.")
@@ -61,16 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
         output_dtype = getattr(model.config, "torch_dtype", torch.float16)
         setattr(model.config, "torch_dtype", output_dtype)
-        for param in model.parameters():
-            param.data = param.data.to(output_dtype)
-    gen_config = model.generation_config  # check and fix generation config
-    if not gen_config.do_sample and (
-        (gen_config.temperature is not None and gen_config.temperature != 1.0)
-        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
-        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
-    ):
-        gen_config.do_sample = True
+        model = model.to(output_dtype)
     model.save_pretrained(
         save_directory=model_args.export_dir,
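On the export change above: the per-parameter cast loop is replaced by a single module-level cast. A tiny illustrative sketch (torch.nn.Linear stands in for the exported model; not repository code):

    import torch

    linear = torch.nn.Linear(4, 4)  # stand-in for the export model
    target_dtype = torch.float16

    # old behaviour: cast each parameter tensor in place
    for param in linear.parameters():
        param.data = param.data.to(target_dtype)

    # new behaviour: one call casts floating-point parameters and buffers alike
    linear = linear.to(target_dtype)
    assert next(linear.parameters()).dtype == torch.float16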

View File

@@ -5,12 +5,11 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
 from ..hparams import FinetuningArguments, ModelArguments
-from ..model import find_all_linear_modules, load_model_and_tokenizer, load_valuehead_params
+from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
 if is_galore_available():
@@ -29,7 +28,13 @@ logger = get_logger(__name__)
 class DummyOptimizer(torch.optim.Optimizer):
-    def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
+    r"""
+    A dummy optimizer used for the GaLore algorithm.
+    """
+    def __init__(
+        self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
+    ) -> None:
         dummy_tensor = torch.randn(1, 1)
         self.optimizer_dict = optimizer_dict
         super().__init__([dummy_tensor], {"lr": lr})
@@ -51,9 +56,11 @@ def create_modelcard_and_push(
     kwargs = {
         "tasks": "text-generation",
         "finetuned_from": model_args.model_name_or_path,
-        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
         "tags": ["llama-factory", finetuning_args.finetuning_type],
     }
+    if data_args.dataset is not None:
+        kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]
     if not training_args.do_train:
         pass
     elif training_args.push_to_hub:
@@ -64,7 +71,7 @@ def create_modelcard_and_push(
 def create_ref_model(
     model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
-) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
+) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@@ -81,16 +88,18 @@ def create_ref_model(
         )
         ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        ref_model, _ = load_model_and_tokenizer(
-            ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+        tokenizer = load_tokenizer(ref_model_args)
+        ref_model = load_model(
+            tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
         )
         logger.info("Created reference model from {}".format(finetuning_args.ref_model))
     else:
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            ref_model, _ = load_model_and_tokenizer(
-                model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+            tokenizer = load_tokenizer(model_args)
+            ref_model = load_model(
+                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
             )
             logger.info("Created reference model from the model itself.")
@@ -99,7 +108,7 @@ def create_ref_model(
 def create_reward_model(
     model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
-) -> "AutoModelForCausalLMWithValueHead":
+) -> Optional["AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reward model for PPO training.
     """
@@ -135,8 +144,9 @@ def create_reward_model(
         )
         reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        reward_model, _ = load_model_and_tokenizer(
-            reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
+        tokenizer = load_tokenizer(reward_model_args)
+        reward_model = load_model(
+            tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
         )
         logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
@@ -156,7 +166,6 @@ def _create_galore_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    max_steps: int,
 ) -> "torch.optim.Optimizer":
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
         galore_targets = find_all_linear_modules(model)
@@ -207,37 +216,27 @@ def _create_galore_optimizer(
         optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
         for param in nodecay_params:
-            param_groups = [dict(params=[param])]
+            param_groups = [dict(params=[param], weight_decay=0.0)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
         for param in decay_params:
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
-        for param in galore_params:
+        for param in galore_params:  # galore params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
-        scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
-        for param in trainable_params:
-            scheduler_dict[param] = get_scheduler(
-                training_args.lr_scheduler_type,
-                optimizer=optimizer_dict[param],
-                num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
-                num_training_steps=max_steps * 2,
-            )
-        def optimizer_hook(param: "torch.Tensor"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
                 optimizer_dict[param].step()
                 optimizer_dict[param].zero_grad()
-                scheduler_dict[param].step()
         for param in trainable_params:
             param.register_post_accumulate_grad_hook(optimizer_hook)
-        optimizer = DummyOptimizer(lr=training_args.learning_rate)  # display scheduler result
+        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
-            dict(params=nodecay_params),
+            dict(params=nodecay_params, weight_decay=0.0),
             dict(params=decay_params, weight_decay=training_args.weight_decay),
             dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
         ]
@@ -252,11 +251,9 @@ def _create_loraplus_optimizer(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
-    if finetuning_args.finetuning_type != "lora":
-        raise ValueError("You should use LoRA tuning to activate LoRA+.")
+    default_lr = training_args.learning_rate
     loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
-    decay_args = {"weight_decay": training_args.weight_decay}
+    embedding_lr = finetuning_args.loraplus_lr_embedding
     decay_param_names = _get_decay_parameter_names(model)
     param_dict: Dict[str, List["torch.nn.Parameter"]] = {
@@ -279,24 +276,110 @@ def _create_loraplus_optimizer(
     optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
     param_groups = [
-        dict(params=param_dict["lora_a"], **decay_args),
-        dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
-        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
-        dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
+        dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
+        dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
+        dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
     logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
     return optimizer
+def _create_badam_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    decay_params, nodecay_params = [], []
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if name in decay_param_names:
+                decay_params.append(param)
+            else:
+                nodecay_params.append(param)
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=nodecay_params, weight_decay=0.0),
+        dict(params=decay_params, weight_decay=training_args.weight_decay),
+    ]
+    if finetuning_args.badam_mode == "layer":
+        from badam import BlockOptimizer
+        base_optimizer = optim_class(param_groups, **optim_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.badam_switch_block_every,
+            start_block=finetuning_args.badam_start_block,
+            switch_mode=finetuning_args.badam_switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+            f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+            f"default start block is {finetuning_args.badam_start_block}"
+        )
+    elif finetuning_args.badam_mode == "ratio":
+        from badam import BlockOptimizerRatio
+        assert finetuning_args.badam_update_ratio > 1e-6
+        optimizer = BlockOptimizerRatio(
+            param_groups=param_groups,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            include_embedding=False,
+            **optim_kwargs,
+        )
+        logger.info(
+            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
+    return optimizer
 def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    max_steps: int,
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
-        return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
+        return _create_galore_optimizer(model, training_args, finetuning_args)
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+def create_custom_scheduler(
+    training_args: "Seq2SeqTrainingArguments",
+    num_training_steps: int,
+    optimizer: Optional["torch.optim.Optimizer"] = None,
+) -> None:
+    if optimizer is not None and isinstance(optimizer, DummyOptimizer):
+        optimizer_dict = optimizer.optimizer_dict
+        scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
+        for param in optimizer_dict.keys():
+            scheduler_dict[param] = get_scheduler(
+                training_args.lr_scheduler_type,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
+            )
+        def scheduler_hook(param: "torch.nn.Parameter"):
+            scheduler_dict[param].step()
+        for param in optimizer_dict.keys():
+            param.register_post_accumulate_grad_hook(scheduler_hook)
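In the layer-wise GaLore path above, DummyOptimizer only keeps Trainer's bookkeeping satisfied, while the real per-parameter optimizers (and now per-parameter schedulers, via create_custom_scheduler) are driven from post-accumulate gradient hooks. A self-contained sketch of that hook mechanism using plain SGD instead of the GaLore optimizer (illustrative only):

    import torch

    model = torch.nn.Linear(8, 2)
    optimizer_dict = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters() if p.requires_grad}

    def optimizer_hook(param: torch.nn.Parameter) -> None:
        # runs right after the gradient for `param` has been accumulated
        if param.grad is not None:
            optimizer_dict[param].step()
            optimizer_dict[param].zero_grad()

    for param in optimizer_dict:
        param.register_post_accumulate_grad_hook(optimizer_hook)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()  # each parameter is updated by its own optimizer inside its hook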

View File

@@ -1,13 +1,11 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
-
-import gradio as gr
-from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
 
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.misc import torch_gc
+from ..extras.packages import is_gradio_available
 from .common import get_save_dir
 from .locales import ALERTS
@@ -17,6 +15,10 @@ if TYPE_CHECKING:
     from .manager import Manager
 
 
+if is_gradio_available():
+    import gradio as gr
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
@@ -35,8 +37,8 @@ class WebChatModel(ChatModel):
     def loaded(self) -> bool:
         return self.engine is not None
 
-    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
-        get = lambda name: data[self.manager.get_elem_by_name(name)]
+    def load_model(self, data) -> Generator[str, None, None]:
+        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang = get("top.lang")
         error = ""
         if self.loaded:
@@ -79,8 +81,8 @@ class WebChatModel(ChatModel):
         yield ALERTS["info_loaded"][lang]
 
-    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
-        lang = data[self.manager.get_elem_by_name("top.lang")]
+    def unload_model(self, data) -> Generator[str, None, None]:
+        lang = data[self.manager.get_elem_by_id("top.lang")]
         if self.demo_mode:
             gr.Warning(ALERTS["err_demo"][lang])
@@ -92,23 +94,29 @@ class WebChatModel(ChatModel):
         torch_gc()
         yield ALERTS["info_unloaded"][lang]
 
-    def predict(
+    def append(
         self,
-        chatbot: List[Tuple[str, str]],
+        chatbot: List[List[Optional[str]]],
+        messages: Sequence[Dict[str, str]],
         role: str,
         query: str,
-        messages: Sequence[Tuple[str, str]],
+    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
+        return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+
+    def stream(
+        self,
+        chatbot: List[List[Optional[str]]],
+        messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
-        chatbot.append([query, ""])
-        query_messages = messages + [{"role": role, "content": query}]
+    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
+        chatbot[-1][1] = ""
         response = ""
         for new_text in self.stream_chat(
-            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             if tools:
@@ -120,18 +128,11 @@ class WebChatModel(ChatModel):
                 name, arguments = result
                 arguments = json.loads(arguments)
                 tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
-                output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
+                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
                 bot_text = "```json\n" + tool_call + "\n```"
             else:
-                output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                 bot_text = result
 
-            chatbot[-1] = [query, self.postprocess(bot_text)]
+            chatbot[-1][1] = bot_text
             yield chatbot, output_messages
-
-    def postprocess(self, response: str) -> str:
-        blocks = response.split("```")
-        for i, block in enumerate(blocks):
-            if i % 2 == 0:
-                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
-        return "```".join(blocks)

View File

@@ -3,7 +3,6 @@ import os
 from collections import defaultdict
 from typing import Any, Dict, Optional
 
-import gradio as gr
 from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
 
 from ..extras.constants import (
@@ -11,15 +10,22 @@ from ..extras.constants import (
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
     PEFT_METHODS,
+    STAGES_USE_PAIR_DATA,
     SUPPORTED_MODELS,
     TRAINING_STAGES,
     DownloadSource,
 )
 from ..extras.misc import use_modelscope
+from ..extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr
 
 
 ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
 DEFAULT_CACHE_DIR = "cache"
+DEFAULT_CONFIG_DIR = "config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user.config"
@@ -33,6 +39,10 @@ def get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
+def get_save_path(config_path: str) -> os.PathLike:
+    return os.path.join(DEFAULT_CONFIG_DIR, config_path)
+
+
 def load_config() -> Dict[str, Any]:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
@@ -52,6 +62,22 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
         json.dump(user_config, f, indent=2, ensure_ascii=False)
 
 
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+    try:
+        with open(get_save_path(config_path), "r", encoding="utf-8") as f:
+            return json.load(f)
+    except Exception:
+        return None
+
+
+def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
+    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
+    with open(get_save_path(config_path), "w", encoding="utf-8") as f:
+        json.dump(config_dict, f, indent=2, ensure_ascii=False)
+
+    return str(get_save_path(config_path))
+
+
 def get_model_path(model_name: str) -> str:
     user_config = load_config()
     path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
@@ -79,9 +105,9 @@ def get_template(model_name: str) -> str:
     return "default"
 
 
-def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
     if finetuning_type not in PEFT_METHODS:
-        return gr.update(value=[], choices=[], interactive=False)
+        return gr.Dropdown(value=[], choices=[], interactive=False)
 
     adapters = []
     if model_name and finetuning_type == "lora":
@@ -92,7 +118,7 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
                 os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
             ):
                 adapters.append(adapter)
-    return gr.update(value=[], choices=adapters, interactive=True)
+    return gr.Dropdown(value=[], choices=adapters, interactive=True)
 
 
 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
@@ -104,12 +130,12 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
     return {}
 
 
-def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
+def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
-    ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
+    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
     datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
-    return gr.update(value=[], choices=datasets)
+    return gr.Dropdown(value=[], choices=datasets)
 
 
-def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
-    return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))
+def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
+    return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))
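The new DEFAULT_CONFIG_DIR together with save_args/load_args gives the WebUI a place to persist and restore argument sets as JSON. A self-contained sketch of that round trip (the directory and file name below are illustrative, not taken from the repository):

import json
import os
from typing import Any, Dict, Optional

CONFIG_DIR = "config"  # plays the role of DEFAULT_CONFIG_DIR above

def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    os.makedirs(CONFIG_DIR, exist_ok=True)
    full_path = os.path.join(CONFIG_DIR, config_path)
    with open(full_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)
    return full_path

def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    try:
        with open(os.path.join(CONFIG_DIR, config_path), "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return None

# hypothetical round trip
path = save_args("my_run.json", {"learning_rate": 5e-5, "num_train_epochs": 3.0})
assert load_args("my_run.json") == {"learning_rate": 5e-5, "num_train_epochs": 3.0}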

View File

@@ -1,13 +1,15 @@
 from typing import TYPE_CHECKING, Dict, Tuple
 
-import gradio as gr
-
 from ...data import Role
+from ...extras.packages import is_gradio_available
 from ..utils import check_json_schema
 
 
+if is_gradio_available():
+    import gradio as gr
+
+
 if TYPE_CHECKING:
-    from gradio.blocks import Block
     from gradio.components import Component
 
     from ..engine import Engine
@@ -15,9 +17,9 @@ if TYPE_CHECKING:
 
 def create_chat_box(
     engine: "Engine", visible: bool = False
-) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
-    with gr.Box(visible=visible) as chat_box:
-        chatbot = gr.Chatbot()
+) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
+    with gr.Column(visible=visible) as chat_box:
+        chatbot = gr.Chatbot(show_copy_button=True)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
@@ -33,16 +35,18 @@ def create_chat_box(
                     temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
                 clear_btn = gr.Button()
 
-    tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
+    tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
 
     submit_btn.click(
-        engine.chatter.predict,
-        [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
+        engine.chatter.append,
+        [chatbot, messages, role, query],
+        [chatbot, messages, query],
+    ).then(
+        engine.chatter.stream,
+        [chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
         [chatbot, messages],
-        show_progress=True,
-    ).then(lambda: gr.update(value=""), outputs=[query])
-    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
+    )
+    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
 
     return (
         chat_box,
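The submit button is now wired as two chained events: a fast append step that records the user turn, then a stream step that fills in the assistant reply incrementally. A minimal sketch of that wiring in Gradio (placeholder handlers, not this repository's code):

import gradio as gr

def append(chatbot, messages, query):
    # record the user turn in both the visible chat and the message history, then clear the textbox
    return chatbot + [[query, None]], messages + [{"role": "user", "content": query}], ""

def stream(chatbot, messages):
    # placeholder generator standing in for the model's streaming reply
    reply = ""
    for token in ["Hello", ", ", "world", "!"]:
        reply += token
        chatbot[-1][1] = reply
        yield chatbot, messages + [{"role": "assistant", "content": reply}]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    messages = gr.State([])
    query = gr.Textbox()
    submit_btn = gr.Button("Submit")
    submit_btn.click(append, [chatbot, messages, query], [chatbot, messages, query]).then(
        stream, [chatbot, messages], [chatbot, messages]
    )

if __name__ == "__main__":
    demo.launch()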

View File

@@ -1,10 +1,13 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Tuple
-
-import gradio as gr
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 
 from ...extras.constants import DATA_CONFIG
+from ...extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr
 
 
 if TYPE_CHECKING:
@@ -22,36 +25,46 @@ def next_page(page_index: int, total_num: int) -> int:
     return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
 
 
-def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
+def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
     try:
         with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
             dataset_info = json.load(f)
     except Exception:
-        return gr.update(interactive=False)
+        return gr.Button(interactive=False)
 
-    if (
-        len(dataset) > 0
-        and "file_name" in dataset_info[dataset[0]]
-        and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
-    ):
-        return gr.update(interactive=True)
+    if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
+        return gr.Button(interactive=False)
+
+    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+    if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
+        return gr.Button(interactive=True)
     else:
-        return gr.update(interactive=False)
+        return gr.Button(interactive=False)
 
 
-def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
+def _load_data_file(file_path: str) -> List[Any]:
+    with open(file_path, "r", encoding="utf-8") as f:
+        if file_path.endswith(".json"):
+            return json.load(f)
+        elif file_path.endswith(".jsonl"):
+            return [json.loads(line) for line in f]
+        else:
+            return list(f)
+
+
+def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
 
-    data_file: str = dataset_info[dataset[0]]["file_name"]
-    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
-        if data_file.endswith(".json"):
-            data = json.load(f)
-        elif data_file.endswith(".jsonl"):
-            data = [json.loads(line) for line in f]
-        else:
-            data = [line for line in f]  # noqa: C416
-    return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
+    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+    if os.path.isfile(data_path):
+        data = _load_data_file(data_path)
+    else:
+        data = []
+        for file_name in os.listdir(data_path):
+            data.extend(_load_data_file(os.path.join(data_path, file_name)))
+
+    return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
 
 
 def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
@@ -67,7 +80,7 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
                 close_btn = gr.Button()
 
             with gr.Row():
-                preview_samples = gr.JSON(interactive=False)
+                preview_samples = gr.JSON()
 
     dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
         lambda: 0, outputs=[page_index], queue=False
@@ -81,7 +94,7 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
     next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
         get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
-    close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
+    close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
 
     return dict(
         data_preview_btn=data_preview_btn,
         preview_count=preview_count,
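With this change the preview accepts a dataset entry whose file_name points either at a single .json/.jsonl file or at a directory of such files, loading every file in the directory. A standalone sketch of that rule plus the page slicing used above (the example path is hypothetical):

import json
import os
from typing import Any, List

def _load_data_file(file_path: str) -> List[Any]:
    with open(file_path, "r", encoding="utf-8") as f:
        if file_path.endswith(".json"):
            return json.load(f)
        elif file_path.endswith(".jsonl"):
            return [json.loads(line) for line in f]
        else:
            return list(f)  # plain text: one record per line

def load_records(data_path: str) -> List[Any]:
    if os.path.isfile(data_path):
        return _load_data_file(data_path)

    records: List[Any] = []
    for file_name in os.listdir(data_path):
        records.extend(_load_data_file(os.path.join(data_path, file_name)))
    return records

PAGE_SIZE = 2
if os.path.exists("data/my_dataset"):  # hypothetical file or directory of .json/.jsonl files
    records = load_records("data/my_dataset")
    first_page = records[PAGE_SIZE * 0 : PAGE_SIZE * 1]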

View File

@@ -1,11 +1,14 @@
 from typing import TYPE_CHECKING, Dict
 
-import gradio as gr
-
+from ...extras.packages import is_gradio_available
 from ..common import DEFAULT_DATA_DIR, list_dataset
 from .data import create_preview_box
 
 
+if is_gradio_available():
+    import gradio as gr
+
+
 if TYPE_CHECKING:
     from gradio.components import Component
 
@@ -21,8 +24,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
 
-    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
-
     input_elems.update({dataset_dir, dataset})
     elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
 
@@ -46,14 +47,14 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         cmd_preview_btn = gr.Button()
-        start_btn = gr.Button()
-        stop_btn = gr.Button()
+        start_btn = gr.Button(variant="primary")
+        stop_btn = gr.Button(variant="stop")
 
     with gr.Row():
-        resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+        resume_btn = gr.Checkbox(visible=False, interactive=False)
         process_bar = gr.Slider(visible=False, interactive=False)
 
-    with gr.Box():
+    with gr.Row():
         output_box = gr.Markdown()
 
     output_elems = [output_box, process_bar]
@@ -68,9 +69,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )
 
-    cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
+    cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
     start_btn.click(engine.runner.run_eval, input_elems, output_elems)
-    stop_btn.click(engine.runner.set_abort, queue=False)
-    resume_btn.change(engine.runner.monitor, outputs=output_elems)
+    stop_btn.click(engine.runner.set_abort)
+    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
+
+    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
 
     return elem_dict

View File

@@ -1,12 +1,15 @@
 from typing import TYPE_CHECKING, Dict, Generator, List
 
-import gradio as gr
-
+from ...extras.packages import is_gradio_available
 from ...train import export_model
 from ..common import get_save_dir
 from ..locales import ALERTS
 
 
+if is_gradio_available():
+    import gradio as gr
+
+
 if TYPE_CHECKING:
     from gradio.components import Component
 
@@ -74,7 +77,7 @@ def save_model(
 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
-        max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
+        max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
         export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
         export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
         export_legacy_format = gr.Checkbox()
@@ -89,12 +92,12 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     export_btn.click(
         save_model,
         [
-            engine.manager.get_elem_by_name("top.lang"),
-            engine.manager.get_elem_by_name("top.model_name"),
-            engine.manager.get_elem_by_name("top.model_path"),
-            engine.manager.get_elem_by_name("top.adapter_path"),
-            engine.manager.get_elem_by_name("top.finetuning_type"),
-            engine.manager.get_elem_by_name("top.template"),
+            engine.manager.get_elem_by_id("top.lang"),
+            engine.manager.get_elem_by_id("top.model_name"),
+            engine.manager.get_elem_by_id("top.model_path"),
+            engine.manager.get_elem_by_id("top.adapter_path"),
+            engine.manager.get_elem_by_id("top.finetuning_type"),
+            engine.manager.get_elem_by_id("top.template"),
             max_shard_size,
             export_quantization_bit,
             export_quantization_dataset,

View File

@@ -1,10 +1,13 @@
 from typing import TYPE_CHECKING, Dict
 
-import gradio as gr
-
+from ...extras.packages import is_gradio_available
 from .chatbot import create_chat_box
 
 
+if is_gradio_available():
+    import gradio as gr
+
+
 if TYPE_CHECKING:
     from gradio.components import Component
 
@@ -25,15 +28,15 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     input_elems.update({infer_backend})
     elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
 
-    chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
+    chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))
 
     load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
-        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
+        lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
     )
 
     unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
-        lambda: ([], []), outputs=[chatbot, history]
-    ).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
+        lambda: ([], []), outputs=[chatbot, messages]
+    ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
 
     return elem_dict

View File

@@ -1,18 +1,21 @@
-from typing import TYPE_CHECKING, Dict, Tuple
-
-import gradio as gr
+from typing import TYPE_CHECKING, Dict
 
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.packages import is_gradio_available
 from ..common import get_model_path, get_template, list_adapters, save_config
 from ..utils import can_quantize
 
 
+if is_gradio_available():
+    import gradio as gr
+
+
 if TYPE_CHECKING:
     from gradio.components import Component
 
 
-def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
+def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
@@ -25,7 +28,7 @@ def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
         adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
-    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+    with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():
             quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
             template = gr.Dropdown(choices=list(templates.keys()), value="default")
@@ -44,7 +47,7 @@ def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
     refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
 
-    return lang, dict(
+    return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,

View File

@@ -1,12 +1,15 @@
 from typing import TYPE_CHECKING, Dict
 
-import gradio as gr
 from transformers.trainer_utils import SchedulerType
 
 from ...extras.constants import TRAINING_STAGES
+from ...extras.packages import is_gradio_available
 from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
 from ..components.data import create_preview_box
-from ..utils import gen_plot
+
+
+if is_gradio_available():
+    import gradio as gr
 
 
 if TYPE_CHECKING:
@@ -21,14 +24,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         training_stage = gr.Dropdown(
-            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
+            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
         )
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
         dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
 
-    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
-
     input_elems.update({training_stage, dataset_dir, dataset})
     elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
 
@@ -68,7 +69,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )
 
-    with gr.Accordion(label="Extra config", open=False) as extra_tab:
+    with gr.Accordion(open=False) as extra_tab:
         with gr.Row():
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
@@ -77,11 +78,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim = gr.Textbox(value="adamw_torch")
 
         with gr.Row():
-            resize_vocab = gr.Checkbox()
-            packing = gr.Checkbox()
-            upcast_layernorm = gr.Checkbox()
-            use_llama_pro = gr.Checkbox()
-            shift_attn = gr.Checkbox()
+            with gr.Column():
+                resize_vocab = gr.Checkbox()
+                packing = gr.Checkbox()
+
+            with gr.Column():
+                upcast_layernorm = gr.Checkbox()
+                use_llama_pro = gr.Checkbox()
+
+            with gr.Column():
+                shift_attn = gr.Checkbox()
+                report_to = gr.Checkbox()
 
     input_elems.update(
         {
@@ -95,6 +102,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             upcast_layernorm,
             use_llama_pro,
             shift_attn,
+            report_to,
         }
     )
     elem_dict.update(
@@ -110,13 +118,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
+            report_to=report_to,
         )
     )
 
-    with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
+    with gr.Accordion(open=False) as freeze_tab:
         with gr.Row():
-            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
-            name_module_trainable = gr.Textbox(value="all", scale=3)
+            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
+            name_module_trainable = gr.Textbox(value="all")
 
     input_elems.update({num_layer_trainable, name_module_trainable})
     elem_dict.update(
@@ -125,21 +134,34 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )
 
-    with gr.Accordion(label="LoRA config", open=False) as lora_tab:
+    with gr.Accordion(open=False) as lora_tab:
         with gr.Row():
-            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
-            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
-            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            lora_target = gr.Textbox(scale=2)
+            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
+            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
+            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
+            create_new_adapter = gr.Checkbox()
 
         with gr.Row():
-            use_rslora = gr.Checkbox(scale=1)
-            use_dora = gr.Checkbox(scale=1)
-            create_new_adapter = gr.Checkbox(scale=1)
+            with gr.Column(scale=1):
+                use_rslora = gr.Checkbox()
+                use_dora = gr.Checkbox()
+
+            lora_target = gr.Textbox(scale=2)
             additional_target = gr.Textbox(scale=2)
 
     input_elems.update(
-        {lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
+        {
+            lora_rank,
+            lora_alpha,
+            lora_dropout,
+            loraplus_lr_ratio,
+            create_new_adapter,
+            use_rslora,
+            use_dora,
+            lora_target,
+            additional_target,
+        }
     )
     elem_dict.update(
         dict(
@@ -147,37 +169,34 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
-            lora_target=lora_target,
+            loraplus_lr_ratio=loraplus_lr_ratio,
+            create_new_adapter=create_new_adapter,
             use_rslora=use_rslora,
             use_dora=use_dora,
-            create_new_adapter=create_new_adapter,
+            lora_target=lora_target,
             additional_target=additional_target,
         )
     )
 
-    with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+    with gr.Accordion(open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
-            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
+            orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
 
-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
-        list_adapters,
-        [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
-        [reward_model],
-        queue=False,
-    ).then(autoset_packing, [training_stage], [packing], queue=False)
-
-    input_elems.update({dpo_beta, dpo_ftx, reward_model})
-    elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
+    input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
+    elem_dict.update(
+        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
+    )
 
-    with gr.Accordion(label="GaLore config", open=False) as galore_tab:
+    with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
-            use_galore = gr.Checkbox(scale=1)
-            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
-            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
-            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
-            galore_target = gr.Textbox(value="mlp,attn", scale=3)
+            use_galore = gr.Checkbox()
+            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
+            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
+            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
+            galore_target = gr.Textbox(value="all")
 
     input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
     elem_dict.update(
@@ -193,38 +212,36 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         cmd_preview_btn = gr.Button()
-        start_btn = gr.Button()
-        stop_btn = gr.Button()
+        arg_save_btn = gr.Button()
+        arg_load_btn = gr.Button()
+        start_btn = gr.Button(variant="primary")
+        stop_btn = gr.Button(variant="stop")
 
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Row():
                 output_dir = gr.Textbox()
+                config_path = gr.Textbox()
 
             with gr.Row():
                 resume_btn = gr.Checkbox(visible=False, interactive=False)
                 process_bar = gr.Slider(visible=False, interactive=False)
 
-            with gr.Box():
+            with gr.Row():
                 output_box = gr.Markdown()
 
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
 
-    input_elems.add(output_dir)
-    output_elems = [output_box, process_bar]
-
-    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
-    start_btn.click(engine.runner.run_train, input_elems, output_elems)
-    stop_btn.click(engine.runner.set_abort, queue=False)
-    resume_btn.change(engine.runner.monitor, outputs=output_elems)
-
     elem_dict.update(
         dict(
             cmd_preview_btn=cmd_preview_btn,
+            arg_save_btn=arg_save_btn,
+            arg_load_btn=arg_load_btn,
             start_btn=start_btn,
             stop_btn=stop_btn,
            output_dir=output_dir,
+            config_path=config_path,
            resume_btn=resume_btn,
            process_bar=process_bar,
            output_box=output_box,
@@ -232,15 +249,27 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )
 
-    output_box.change(
-        gen_plot,
-        [
-            engine.manager.get_elem_by_name("top.model_name"),
-            engine.manager.get_elem_by_name("top.finetuning_type"),
-            output_dir,
-        ],
-        loss_viewer,
-        queue=False,
-    )
+    input_elems.update({output_dir, config_path})
+    output_elems = [output_box, process_bar, loss_viewer]
+
+    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
+    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
+    arg_load_btn.click(
+        engine.runner.load_args,
+        [engine.manager.get_elem_by_id("top.lang"), config_path],
+        list(input_elems) + [output_box],
+        concurrency_limit=None,
+    )
+    start_btn.click(engine.runner.run_train, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort)
+    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
+
+    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
+        list_adapters,
+        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
+        [reward_model],
+        queue=False,
+    ).then(autoset_packing, [training_stage], [packing], queue=False)
 
     return elem_dict

Some files were not shown because too many files have changed in this diff.