215 Commits

Author SHA1 Message Date
hiyouga
3cef844079 fix setup
Former-commit-id: 7d3e7db46a5f8672dd57fa5fcc03822e175047f9
2024-04-28 03:49:13 +08:00
hiyouga
4dcd47100d fix llava rlhf
Former-commit-id: f6863cbbcbf960d6481296c6cae3e40fd70e4e14
2024-04-28 03:01:49 +08:00
hiyouga
a412b4ed4a add models to 0.7.0
Former-commit-id: 436d3754452f839c617839ab3bbaacc4a8908e19
2024-04-28 01:50:30 +08:00
hiyouga
544a6259b6 update readme
Former-commit-id: c9190fe36f511c3a5149d45c85a10b02a57fa88a
2024-04-26 23:39:19 +08:00
hiyouga
c501f377dd release v0.7.0
Former-commit-id: 45bb89cb4d26a6b3fb5360bc90ab950738fe4920
2024-04-26 23:18:00 +08:00
hiyouga
cb8b8f40cd update readme
Former-commit-id: f3d4b46338d4d484b205d0651a1fa7b2e77a1654
2024-04-26 20:09:14 +08:00
hiyouga
70bed8ad8f support Qwen1.5 110B
Former-commit-id: d6e5ecaf4109127bab24e39a0696076bceb0b37c
2024-04-26 19:59:22 +08:00
hiyouga
51f776ae2a fix llava qlora
Former-commit-id: 01c5a669f6fe598aac1758a700a7607da37db1bc
2024-04-26 18:00:23 +08:00
hiyouga
697bc20941 add llava to llamaboard
Former-commit-id: deaaff0a9de0eef9691991c99cd797461b1165cc
2024-04-26 06:41:35 +08:00
hiyouga
1480e3a88f update readme
Former-commit-id: df1155245d3f71ba4f3361d43aa662ab3b024de8
2024-04-26 05:49:26 +08:00
hoshi-hiyouga
19029d5b0f Merge pull request #3454 from hiyouga/mllm
Support fine-tuning LLaVA-1.5 MLLM @BUAADreamer 

Former-commit-id: c4195d1e26349795f7aad5c10a8a9e2abb7b64a3
2024-04-26 05:46:29 +08:00
hiyouga
7773ac0ead update readme
Former-commit-id: 41728fd74de7bec0cc6135aef9dfa3ae9fe7af73
2024-04-26 05:44:30 +08:00
hiyouga
23b881bff1 support mllm hf inference
Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
2024-04-26 05:34:58 +08:00
hoshi-hiyouga
10a6c395bb Merge pull request #3450 from BUAADreamer/mllm
Add Multimodal LLM Finetuning

Former-commit-id: 7cacbcfdf7391080ef43eb2b2c79a5237e6120e8
2024-04-26 05:30:30 +08:00
hoshi-hiyouga
f9a7732a1f Update preprocess.py
Former-commit-id: 0e376eab23d38b8fca05f054f3cde308756ee3b1
2024-04-26 04:10:28 +08:00
hoshi-hiyouga
c37582af02 Update aligner.py
Former-commit-id: 855489074c469f47572153df0fa1e251b187b232
2024-04-26 03:48:34 +08:00
hoshi-hiyouga
ece67f8c7f Update parser.py
Former-commit-id: 4df75e8a9a391565cc3eec69bc0ebf5d5192de61
2024-04-26 03:35:39 +08:00
hoshi-hiyouga
e1838e76fe Update loader.py
Former-commit-id: 6a5f2e2ab7304113ff71cb77aafff6a1f74831f8
2024-04-26 03:33:07 +08:00
hoshi-hiyouga
2eede9ffd6 Update workflow.py
Former-commit-id: 5b8b5b975716d539ae2fae8536f79e106aa0b566
2024-04-26 03:29:12 +08:00
hoshi-hiyouga
a6f6b406b3 Update loader.py
Former-commit-id: 72d4817a15f6916706828ea2a61d808183c23773
2024-04-26 03:22:40 +08:00
hoshi-hiyouga
279439abbe update hparam name
Former-commit-id: 9941adfbf06db37f8ba32c4555f6e58e27188aaf
2024-04-26 02:49:39 +08:00
hoshi-hiyouga
13117b69d7 delete llava template (use vicuna)
Former-commit-id: 420e64970e5a0e45453041927e0366ee8beb73d5
2024-04-26 02:20:47 +08:00
BUAADreamer
5d03ac642d modify some bug
Former-commit-id: 593b7b004df74bd24361c9883401a656c08fb589
2024-04-25 22:59:46 +08:00
BUAADreamer
5062ee547e modify some style
Former-commit-id: 1291c7ee39361dd75247c67f04dcf20b472faf83
2024-04-25 22:40:53 +08:00
BUAADreamer
59817c27e3 modify some style
Former-commit-id: d578a90cefa7ec813355795bdd6ead5ee558ce26
2024-04-25 22:40:25 +08:00
BUAADreamer
759bee48d2 merge some func
Former-commit-id: 3085107c44715e4b2ca96d73b20d90c172b95219
2024-04-25 22:35:17 +08:00
BUAADreamer
514ffafc12 modify some style
Former-commit-id: 053062abc007014a7fde95c5ae9f4d859893d8ad
2024-04-25 22:04:09 +08:00
BUAADreamer
8b2a735c14 modify some style
Former-commit-id: b016e6a671a2f228f0bdd9b8d5995b4669609655
2024-04-25 21:58:18 +08:00
BUAADreamer
10d59e9e4a make dataset script
Former-commit-id: 25892f958da14976025a775febf628cd0e0a3d85
2024-04-25 21:32:01 +08:00
BUAADreamer
058ed5e607 modify style
Former-commit-id: c1f1df99e4dc3d0aadf1207b4e9a16218187fd5a
2024-04-25 21:29:50 +08:00
BUAADreamer
110c2ce2a5 modify style
Former-commit-id: 3bffc1e1b8bcc4582cebea06d35e5146163c7bec
2024-04-25 21:27:48 +08:00
BUAADreamer
c425436676 modify style
Former-commit-id: 54b713d0c4ffdfc6a7faeb14471b58bb1cd8acf5
2024-04-25 21:15:16 +08:00
BUAADreamer
266fe908e3 Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: c4bb5af69c5bbf0b1ea044cbb2b18acddc6733ac
2024-04-25 21:08:40 +08:00
BUAADreamer
dbd905438b add some
Former-commit-id: 8d035a849c4a441d457791aab073861adf69a09f
2024-04-25 21:08:32 +08:00
hoshi-hiyouga
d64c87f928 Merge pull request #3449 from hiyouga/mllm
add webui backend option

Former-commit-id: 372fcedef40b79fe8bd3932c06c720f2a03db6e6
2024-04-25 20:58:16 +08:00
hiyouga
29eebef696 add webui backend option
Former-commit-id: 3764586cb3ed64fe376d0ae420ff5690c28459e2
2024-04-25 20:49:23 +08:00
hiyouga
7bfbcb1fe3 vllm + lora support
Former-commit-id: 8cb86ba355195f5d6dcb95ee6b6b7203463a34db
2024-04-25 20:24:31 +08:00
BUAADreamer
9b210cf4b3 rm some
Former-commit-id: 2c85b4fabbebd8b51eee53f5d29184d4a6e97569
2024-04-25 20:09:43 +08:00
BUAADreamer
f74e640565 Merge branch 'hiyouga:main' into main
Former-commit-id: 131d0bcd554dedd794add7eb3d7b1201cac80e7c
2024-04-25 20:02:50 +08:00
BUAADreamer
d1d08d066a merge data part to the text stream
Former-commit-id: 80537d580119d9d5a06ab236a5284aaae2f83b5b
2024-04-25 19:58:47 +08:00
hiyouga
6be321b5da fix #3374
Former-commit-id: 0097d7968b3b570e1705caff26f42d9ed71ad974
2024-04-25 19:56:49 +08:00
BUAADreamer
3c792174db merge data part to the text stream
Former-commit-id: 7ee20286d9bcc2d5378bfd6bb02cd3648396d873
2024-04-25 19:19:59 +08:00
hiyouga
9aeb88c426 add export_device in webui #3333
Former-commit-id: 30ebd3652809d73941e0a5e4a8be11d989faf98d
2024-04-25 19:02:32 +08:00
BUAADreamer
00e2a272ef merge model part to the text stream
Former-commit-id: b6fcb832ddaed4647d6f2b926f3dfccd47f3ea84
2024-04-25 08:20:41 +08:00
BUAADreamer
5142349661 remove error
Former-commit-id: 2bcd1c7dc3595f17ae4e2c4475196cc2d03d0e75
2024-04-25 01:01:59 +08:00
BUAADreamer
0e3cc52327 remove conflicts
Former-commit-id: e5750ee202eb67cf5fc54f464548e2eb43d00900
2024-04-25 00:56:06 +08:00
BUAADreamer
6c1db2d012 remove conflicts
Former-commit-id: f8b637eb76cba7ec229e2978068805ad1cca8adb
2024-04-25 00:34:22 +08:00
BUAADreamer
12c51655ce add llava and instructblip
Former-commit-id: 142fb6f4541a1acfefe66ff2574dabde53b00c06
2024-04-25 00:22:43 +08:00
hiyouga
36be12a3b7 update tool template
Former-commit-id: c72a1981859818c257c5271d32e03c9d3c344206
2024-04-25 00:21:34 +08:00
hiyouga
21fac4c98c fix log level
Former-commit-id: 8d21302f6201b3f33c10f61f3559bd95be3363c2
2024-04-24 23:42:59 +08:00
hiyouga
83404c4fa9 support new special token #3420
Former-commit-id: f5c6a47f5193ab3a6c137580992bdcce0b31fdd5
2024-04-24 23:39:31 +08:00
hoshi-hiyouga
12f852b8d4 fix phi template
Former-commit-id: 14a1ff665eaebfc618229efbe96f09848d52faec
2024-04-24 13:55:14 +08:00
hoshi-hiyouga
a88873116a fix webchatmodel
Former-commit-id: dc6d8b5dc42c363dd180aaf90c9a2f2d0cce6725
2024-04-24 13:54:21 +08:00
hoshi-hiyouga
7cfcd69c64 fix inference in llamaboard
Former-commit-id: 5e631915157083b61e2d5a183e0c91f2d11f416e
2024-04-24 13:53:39 +08:00
hiyouga
a5eabbe933 add olmo 1.7
Former-commit-id: 86a3fb3a141d2702b15af08df36ffcf9b3d6de14
2024-04-24 05:50:50 +08:00
hiyouga
aa25716a5d add dbrx and jamba models
Former-commit-id: ce35c80b4b00152185285d6064939803d14487f0
2024-04-24 05:39:52 +08:00
hiyouga
94c8219575 fix bug
Former-commit-id: 38e164fe4aaea6f0baf121a720291ca42643ba8c
2024-04-24 05:21:18 +08:00
hiyouga
ad24a2a0c9 fix bug
Former-commit-id: 271c24d2c82d645fa9072e6de94ca38f20411537
2024-04-24 05:10:07 +08:00
hiyouga
c05027d14a remove redundant code
Former-commit-id: 4a7a7ad2bcdc493458084f5f3d384239228b7d5a
2024-04-24 05:02:18 +08:00
hiyouga
5420905a2e support unsloth generate
Former-commit-id: 0ef1ad9f505dba71db9342f524cc3a7565e5e09e
2024-04-24 04:46:53 +08:00
hiyouga
03f2e3284a refactor patcher
Former-commit-id: 263cfe1294f5c3188f5e8d65791f35ee0d87315a
2024-04-24 03:02:23 +08:00
hiyouga
d2bb1b3a6b reenable sdpa and fast tok by default
Former-commit-id: 9e00902dbedc71d55743d1bf237843506a557891
2024-04-24 02:18:44 +08:00
hiyouga
35c4a2c212 fix #3347 #3387
Former-commit-id: c253c18185a29b59190f3e0ed236c2bb4c788085
2024-04-24 01:30:16 +08:00
hiyouga
1e4010a1fb support phi-3
Former-commit-id: 7e8ffa9beee3893e051ceeade443bd56c4a07b1c
2024-04-24 00:28:53 +08:00
BUAADreamer
1451297c78 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: 67800c565b086f362b8cf131b0c9babaa7a7ebc7
2024-04-23 19:22:42 +08:00
BUAADreamer
0b99b13786 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: b78b5f290aa38a7454e101ee9703fb6fac5064ac
2024-04-23 18:47:03 +08:00
BUAADreamer
f5edbf2b49 Merge branch 'hiyouga:main' into main
Former-commit-id: 6287d1b789c631205c1033adf036e28deaef4167
2024-04-23 18:46:12 +08:00
BUAADreamer
ab6dc0ea30 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: a730f89a972f1a9d37c718c716f199cb8d4903b2
2024-04-23 18:45:43 +08:00
hiyouga
79d34ce0f3 update examples
Former-commit-id: 8bf55682cdfbbdca0f01073eac0084c20a6a09d1
2024-04-23 18:29:46 +08:00
hiyouga
1d2e372a8e update readme
Former-commit-id: d4eaee262a64e716ce475dc4eb18d8d9697d8dd8
2024-04-22 17:09:17 +08:00
hiyouga
f6a53d83c8 update readme
Former-commit-id: 3eab580703ee01a0d2d75e7f01df5165af551386
2024-04-22 00:51:35 +08:00
hiyouga
4ec56dd958 update readme
Former-commit-id: fdca136309709e43d75a831252b9375a5a99635a
2024-04-22 00:42:25 +08:00
hiyouga
ba06eb65ca update readme and examples
Former-commit-id: 27dd9bf201c24f7804811398bc2758966ec78432
2024-04-22 00:37:32 +08:00
hiyouga
be716972fe remove extras
Former-commit-id: d67e972f8c3d5273e589c8c85c0a1620f59785c5
2024-04-22 00:35:41 +08:00
hiyouga
719585a128 update readme
Former-commit-id: 3a8c17907c71f46b1b37501e2afdc99ad89fb4bc
2024-04-22 00:21:01 +08:00
hiyouga
348f29aa50 set dev version
Former-commit-id: b9557887d7506ff57b2b2bf490092aac4e4becf0
2024-04-21 23:14:30 +08:00
hiyouga
c8fe3f544b release v0.6.3
Former-commit-id: 947572af8de201669598f54735f35b50bb719d71
2024-04-21 23:13:23 +08:00
hiyouga
0f1ad7140f fix #3366
Former-commit-id: dc20237455c36de44f8922539d7dfadd8bedb12f
2024-04-21 21:34:25 +08:00
hiyouga
233e167f68 fix optimizers
Former-commit-id: f811eee2fa12a89a55a9c5d3a05a1521b4347727
2024-04-21 20:40:54 +08:00
hiyouga
1d341dcd83 fix #3365
Former-commit-id: 415ce41e8fa887e980e5bd575c8e95bd4076b90b
2024-04-21 19:20:18 +08:00
hiyouga
d16561e7a4 fix bug in galore optimizer
Former-commit-id: c05ac23261a5a8ba893c2918a43dc7777307407b
2024-04-21 18:53:22 +08:00
hiyouga
f8e219dc81 fix mod stuff
Former-commit-id: cf3988226e6398c67bb2955578e436fc505aa5c5
2024-04-21 18:11:10 +08:00
hoshi-hiyouga
3365cc8cf0 Merge pull request #3338 from astramind-ai/main
Adding Mixture of Depth

Former-commit-id: 4da2ece53353b63e672ff529d6beba41ff710c14
2024-04-21 18:05:52 +08:00
hoshi-hiyouga
3a5e68b7d9 fix #3348
Former-commit-id: aa5e921c00f60074eceb2f9d4d8837cc713edba6
2024-04-20 10:34:09 +08:00
hiyouga
0cb596fee1 add dpo mix dataset
Former-commit-id: 6def3f8bfa51b2d9d73af112352ce07db972e4c9
2024-04-20 01:31:38 +08:00
hiyouga
b3b5b530d1 fix #3352
Former-commit-id: f315f8e8ec916b82bac94a159e55839ff155c6b5
2024-04-19 22:40:01 +08:00
hiyouga
9225c15c88 fix llama3 template
Former-commit-id: 20e95250168fbe081c779b2e1ff23f5df3ce02f7
2024-04-19 15:46:51 +08:00
Marco
abd9fed445 fix small typo
Former-commit-id: 5638a03cd0cf8119ff366b3b3e303b5a2351b065
2024-04-18 20:33:29 +02:00
Marco
44cda2eece Added Mixture of Depths
Former-commit-id: 75dd98b9abc847e22cb263c17ebcd2ca5dd98345
2024-04-18 20:31:24 +02:00
hoshi-hiyouga
8397808d1d support llama3
Former-commit-id: c1eabb751a5fd73b710714451b146732e0ed4558
2024-04-19 01:13:50 +08:00
hiyouga
9e1bd6420d fix #3324
Former-commit-id: 5e710c4ac331f3400534d33b2646c4108c898d98
2024-04-18 15:34:45 +08:00
hiyouga
619264c854 tiny fix
Former-commit-id: 86399ca8c06273c42c2b184664ae25d3405b3bf6
2024-04-18 00:22:17 +08:00
hiyouga
1ebac62e3d update readme
Former-commit-id: a49112a74339ba77bfec53f7870e821fe148db2c
2024-04-17 23:40:49 +08:00
hiyouga
ce9bdb3509 add mixtral 8x22B models
Former-commit-id: eccbeecff0909e1fa124b5439ffbbfbc5607e1d6
2024-04-17 23:35:59 +08:00
hiyouga
0c8d6369ac add CodeQwen models
Former-commit-id: 9f6094241391f8f717818c8ba94e11d1791b4a5c
2024-04-17 23:27:22 +08:00
hiyouga
bee796f6b5 fix #3316
Former-commit-id: 7395e9e90a209228ff563ab54319955608850fc3
2024-04-17 22:54:34 +08:00
hiyouga
9f6349a333 fix #3317
Former-commit-id: 7dce1763be4374cf616d96db95ae964ff510a9d6
2024-04-17 22:17:19 +08:00
hiyouga
171a029c5e lint
Former-commit-id: 917d65ce65024d17a5030bc57083a427cfae16d7
2024-04-16 18:21:09 +08:00
hoshi-hiyouga
eaefaa0fe0 Merge pull request #3291 from codemayq/main
support for previewing custom dataset in directory format

Former-commit-id: 40d89152282101a7c08f53e72c2ad7124a0595f3
2024-04-16 18:12:09 +08:00
hiyouga
d301f0a64b Update parser.py
Former-commit-id: 92c2133896c20054db86dd53508c982e39bd5ca0
2024-04-16 18:09:31 +08:00
hiyouga
0a1578e4e3 update readme and gradio version
Former-commit-id: 4029b60ddcbd15b5354503c51178f0f5e7e9aedf
2024-04-16 18:09:16 +08:00
hiyouga
a4167fd925 support badam for all stages
Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
2024-04-16 17:44:48 +08:00
hoshi-hiyouga
42084e08ae Merge pull request #3287 from Ledzy/badam
[Feature] Add BAdam algorithm

Former-commit-id: 10a5e1e65b34b03e5ca2a41bf6ded09a3fb25f0c
2024-04-16 17:32:16 +08:00
hoshi-hiyouga
9d23f5dc89 Update utils.py
Former-commit-id: 01147536b2bb507e87e033fa696e9eb39fe96bbe
2024-04-16 17:30:12 +08:00
hoshi-hiyouga
5978427ae0 Update trainer.py
Former-commit-id: c6163be1444c00dd000f288e2f834968bd932981
2024-04-16 17:29:52 +08:00
hoshi-hiyouga
c7c216069c Update utils.py
Former-commit-id: 7edf4dbed88b8034282f14fd6e0cb6f7f9e5f805
2024-04-16 17:29:30 +08:00
hoshi-hiyouga
cde9d1b917 Update patcher.py
Former-commit-id: 494e6a1e05b38f5ff61d83327303614f53c92e64
2024-04-16 17:29:19 +08:00
hoshi-hiyouga
96213f04b0 Update adapter.py
Former-commit-id: 8f7b75b26f020d8ae85baab7b082475c3bfeb512
2024-04-16 17:28:12 +08:00
hoshi-hiyouga
7ecea08b9b Update parser.py
Former-commit-id: 898239883afc79f03abd0dc276eef901662a9591
2024-04-16 17:27:25 +08:00
hoshi-hiyouga
191971865d Update parser.py
Former-commit-id: 2f3da8169d18b026760cc0ac7dd6141bdd08c932
2024-04-16 17:27:02 +08:00
hoshi-hiyouga
ff4f587dd9 Update finetuning_args.py
Former-commit-id: 3a23d900aea74078f0bc8cf73fac860a4ce3df67
2024-04-16 17:26:30 +08:00
hoshi-hiyouga
de728d0371 Update sft.sh
Former-commit-id: 2b4b1562e91bbb02e345e71b7721da9333c0791b
2024-04-16 17:25:40 +08:00
hoshi-hiyouga
d08e09642d Update requirements.txt
Former-commit-id: 1e45537ca0bb4d49b4147df01122e365b3d617e4
2024-04-16 17:10:17 +08:00
hoshi-hiyouga
351493b183 Update setup.py
Former-commit-id: 5df30ea166aff29d48ff83a22ac6ef1611ce3e35
2024-04-16 17:10:02 +08:00
Jonery
86ab47e121 remove badam from core requirements
Former-commit-id: fa5898944a3867ac5108dd0d579ca0677c87d3d6
2024-04-16 12:25:50 +08:00
Jonery
6dd6b3e396 resolve gradient checkpointing issue.
Former-commit-id: 6df9135d063bb6102f0cbcdf0d702076f5febbae
2024-04-16 12:05:27 +08:00
codingma
5f1418a68b add check
Former-commit-id: 008f6498977c243c80e87242f05c9cf9573541ac
2024-04-16 10:56:39 +08:00
codingma
7b97a79efc support for previewing custom dataset in directory format
Former-commit-id: 501cff38c819f06f15194907ce7e052d5f28025a
2024-04-16 10:43:14 +08:00
hiyouga
ce4f653121 add empty template
Former-commit-id: a325ffa8a668bec354d2636683806acef105e196
2024-04-16 03:10:02 +08:00
hiyouga
b053c6454e update readme
Former-commit-id: 8f233745c3aa7a6ef57f275bec80ee731ff76de3
2024-04-16 02:36:54 +08:00
hiyouga
ebf0f4a77c update readme
Former-commit-id: f9a246572c1ec0e4b36bff237c6523ce629b7000
2024-04-16 02:35:36 +08:00
hiyouga
efa808069a support unsloth 2024.4
Former-commit-id: 14a83f8bc4fe44783252378fce59198194a96bb8
2024-04-16 00:25:03 +08:00
hiyouga
b5c5283dd6 add codegemma
Former-commit-id: 9324176525c2eda22962b0ca1895009b6237e6e3
2024-04-16 00:11:15 +08:00
hiyouga
b638c65519 support cohere commandR #3184
Former-commit-id: e077c36872740f6b2ac255aee9da6c4c70f28977
2024-04-15 23:26:42 +08:00
Jonery
d4d471450f Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
2024-04-15 23:15:27 +08:00
hoshi-hiyouga
3144bdec2c Merge pull request #3254 from marko1616/feature/Add-support-for-CohereForAI/c4ai-command-r-plus
Add template&support for c4ai-command-r/plus (tested)

Former-commit-id: 41d39ec4889abad050820bf153133ac3a11228a3
2024-04-15 22:59:35 +08:00
hoshi-hiyouga
c6d6c4c209 Update template.py
Former-commit-id: 00b8be7dafa65e13b344724a8d3855919ee4f631
2024-04-15 22:58:01 +08:00
hoshi-hiyouga
f5f1589662 Update constants.py
Former-commit-id: 39199f712aa7b7a1c66080d9c84651fd2eb0b425
2024-04-15 22:56:55 +08:00
hiyouga
276f2cb24e update examples
Former-commit-id: 369294b31c8a03a1cafcee83eb31a817007d3c49
2024-04-15 22:14:34 +08:00
marko1616
952b785bb3 change default_system accroding to official template
Former-commit-id: 7ad9029c5e77a87a7c324b8f90b4f80a31a5c78b
2024-04-15 20:45:46 +08:00
marko1616
72dd676208 Revert "Add support for function call(Not strictly following origin)"
This reverts commit dfaa31e991 [formerly 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6].


Former-commit-id: fac9cc6e01dd8f3bc449b656804476e1871326f0
2024-04-15 20:27:09 +08:00
marko1616
dfaa31e991 Add support for function call(Not strictly following origin)
Former-commit-id: 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6
2024-04-15 20:16:52 +08:00
hoshi-hiyouga
86556b1c74 Merge pull request #3261 from khazic/main
Added specimens for single-card full parameter prediction

Former-commit-id: 60df2a9519fbd8215c3afacc831b0cc89006457a
2024-04-15 16:30:57 +08:00
hoshi-hiyouga
0c80751e87 Merge pull request #3276 from liu-zichen/fix_mixtral
fix: turn on output_router_logits of mixtral
Former-commit-id: 07bbaf5c67d00a152e5304e81b15fd9189e7bb99
2024-04-15 15:38:16 +08:00
hiyouga
9338f878a3 fix #3273
Former-commit-id: 3b20c89b342a068356ffc29c3724b645775c65db
2024-04-15 15:32:58 +08:00
liuzc
fde3d91242 fix: mixtral output_router_logits
Former-commit-id: ab3171ea97ec968b972287287ef9ee2502c6d37c
2024-04-15 12:11:49 +08:00
khazic
19adfb88a9 Upgrade README.md
Former-commit-id: 697f768d7185789ee054c94f4f161a65b8a505bc
2024-04-13 20:50:49 +08:00
khazic
daaafa900a Added specimens for single-card full parameter prediction
Former-commit-id: d8d4fb9fa4b0e1950a453682e5e186f34f085dee
2024-04-13 20:45:19 +08:00
marko1616
0dcc9e0bca Typo fix
Former-commit-id: 607625497738b2c8be736be7b0bd5c6f4cbaad5e
2024-04-13 17:30:21 +08:00
marko1616
aeec78b35c Typo fix
Former-commit-id: 51b1e49e288e66c1b0c24ac070201c988fb2a389
2024-04-13 07:52:11 +08:00
marko1616
c991654cb4 Add c4ai-command-r-plus link
Former-commit-id: acaf953ca46eca8fb378067f4ada133654e4f088
2024-04-13 07:32:40 +08:00
marko1616
f328413646 Add template&support(Not tested)
Former-commit-id: 60bb60c4dc30a9641ddb57a44ef126f0768566c4
2024-04-13 04:31:33 +08:00
hiyouga
106a0104da fix #3247
Former-commit-id: bb67c66f80627805b585d157ba807c0ce378d3f2
2024-04-12 17:41:33 +08:00
hiyouga
5486ea09e3 fix model card
Former-commit-id: 920e7149bf2b559c9829aa4b11cfb6d00bbb2f9e
2024-04-12 17:11:59 +08:00
hiyouga
31bbbb6d13 fix #3238
Former-commit-id: 4d7e81ab4722d13bec6ca1af141f94bdc74d0883
2024-04-12 14:28:11 +08:00
hiyouga
1a77de82fa set dev version
Former-commit-id: f6cc76571d2c789675883a18e0db3d0c61f33808
2024-04-11 20:27:34 +08:00
hiyouga
7468f2535c release v0.6.2
Former-commit-id: f92ad0a62d957b595f6a76a5403216b163eb3d17
2024-04-11 20:08:51 +08:00
hiyouga
38e4f22605 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Factory
Former-commit-id: 23ff02c1fd3787daf0bc6ac237c8897d02f726e4
2024-04-10 23:58:18 +08:00
hiyouga
2bc2fe7b5e fix #3225
Former-commit-id: 94110ecf27c32e263f1f2ee61842a3a301b9e089
2024-04-10 23:57:59 +08:00
hoshi-hiyouga
6d0140d8a0 Merge pull request #3201 from kno10/patch-1 and fix #3200
Pass additional_target to unsloth

Former-commit-id: 080a96c52f489fda0d315a77e26c4f6f5d69784a
2024-04-10 00:58:48 +08:00
hoshi-hiyouga
7856f98965 Update adapter.py
Former-commit-id: 720fde3683529ed7e08ac27c7c4598c6bdc30d44
2024-04-10 00:57:51 +08:00
hoshi-hiyouga
e25ddef08c Update adapter.py
Former-commit-id: a84b8d17dbf221259212e81931d80bcdd6284ad7
2024-04-10 00:57:30 +08:00
Erich Schubert
95a4589bbf Pass additional_target to unsloth
Fixes #3200

Former-commit-id: f8f87f5b0549cba6a011749c42064047f82ba577
2024-04-09 17:53:40 +02:00
hiyouga
566d71b7a9 fix quant infer and qwen2moe
Former-commit-id: b75d16767f35c36e2cf2aaab8a3844135085bccf
2024-04-09 17:12:59 +08:00
hiyouga
6030a4a720 tiny fix
Former-commit-id: d8f1ff51d4c920d4d0aeb9d53db29d1efb733c85
2024-04-08 21:28:39 +08:00
hoshi-hiyouga
5dc0cb94d4 Merge pull request #3161 from hiyouga/feature/add-mediatek-model
support Breeze-7B

Former-commit-id: af92ac8b62b919a75673011a1c56832e67882ee8
2024-04-08 20:56:51 +08:00
codingma
325dafcbb0 add empty line
Former-commit-id: 1c6c2e611d10e9fa662e3f4e1e7d23b80ae496cb
2024-04-07 18:28:08 +08:00
codingma
1a8a8b8651 rename template to breeze
Former-commit-id: 1223e6358dab52b4e1505057f1b16fd9d527c79e
2024-04-07 18:27:20 +08:00
hoshi-hiyouga
61a495cb1e Merge pull request #3160 from sliderSun/main
support Qwen1.5-32B

Former-commit-id: 1e5a5882dd494c3e9cf5eae2e0a485ce49d1863c
2024-04-07 18:00:40 +08:00
codingma
75866aa020 rename template to breeze
Former-commit-id: 1d894e7cfb73b8a29dababb554d051bd50e4f01d
2024-04-07 11:39:54 +08:00
codingma
9e4fda326d support https://github.com/hiyouga/LLaMA-Factory/issues/3152
Former-commit-id: 708f0ab4b0aa72e2c73ca36eb9ed058910e43092
2024-04-07 11:34:01 +08:00
sliderSun
1131ddfaff fix spell error
Former-commit-id: e6d36a2e593ebc1193b1735075c4ddb5d9f54990
2024-04-07 10:59:15 +08:00
sliderSun
9f437b5c43 support Qwen1.5-32B
Former-commit-id: c419adf1697b92520342f4ffa697c84bf19ca37d
2024-04-07 10:56:03 +08:00
sliderSun
0cc03d3f05 support Qwen1.5-32B
Former-commit-id: 8f2c67b95a8e177eb4096382417a70cacba38e90
2024-04-07 10:26:13 +08:00
hiyouga
04fc2f78bf update readme
Former-commit-id: 1cf15547e2420a3e5f7a969c21c10c7fbdfc71fe
2024-04-07 00:48:24 +08:00
hiyouga
3ac333fc6a update examples
Former-commit-id: de40ad62ba3d4c74c69de97b39cc79786ac28f0f
2024-04-04 14:48:21 +08:00
hiyouga
a246ac1914 tiny fix
Former-commit-id: 70aceecb27e72095c05462d01f956061669b267e
2024-04-04 02:19:03 +08:00
hiyouga
48ceac845c back to gradio 4.21 and fix chat
Former-commit-id: 695734a40a702ea059d855da54080cc8d161e41a
2024-04-04 02:07:20 +08:00
hiyouga
b1986a06b9 fix bug in latest gradio
Former-commit-id: 44a962862b4a74e50ef5786c8d5719faaa65f63f
2024-04-04 00:55:31 +08:00
hiyouga
43d134ba29 fix requires for windows
Former-commit-id: 5e25fae40b7ea9cfa72717efbe3677199ca9608f
2024-04-03 21:56:43 +08:00
hiyouga
1348f7d860 fix resize vocab at inference #3022
Former-commit-id: c243720b89eec0af2872fa3c7980a0026d893f4d
2024-04-03 18:14:24 +08:00
hiyouga
f6530222f7 fix #3116
Former-commit-id: b7256aa33d761280751518c20f29f9b8ea3fb025
2024-04-03 14:47:59 +08:00
hiyouga
a74a7585e0 update vllm example
Former-commit-id: 2df6d2eacfa27ebc69455696b93649624c1facbe
2024-04-02 22:45:20 +08:00
hiyouga
5bf0cca2b8 update readme
Former-commit-id: 7ea7333b51be6b1120fc0b13675f5a0ac3c5a12b
2024-04-02 22:17:48 +08:00
hiyouga
755b6511ff update examples
Former-commit-id: 2715cfe20f6f4532bebaa47b80ccd5df43d6a490
2024-04-02 21:09:25 +08:00
hiyouga
35621c6089 add zh readme
Former-commit-id: 389a170a4d42c56c71c0e17bbe018c4cb1983b5a
2024-04-02 20:58:45 +08:00
hiyouga
38b59664e6 update examples
Former-commit-id: c078582a759f6bce6e760cd39a05883f7eb194fe
2024-04-02 20:51:21 +08:00
hiyouga
933a084999 update examples
Former-commit-id: bf36b16e48d6438de6d0b2f2bfe33f7895699b9d
2024-04-02 20:41:49 +08:00
hiyouga
c1510d19c7 update readme
Former-commit-id: 9b8e7ccdab167f53fb897e1940562682324e8ff0
2024-04-02 20:37:37 +08:00
hiyouga
2074cf99fb update readme
Former-commit-id: 0c73d3c8a5762a8f119b27322ffd52a61de6fe38
2024-04-02 20:22:11 +08:00
hiyouga
b12176d818 simplify readme
Former-commit-id: 0da6ec2d516326fe9c7583ba71cd1778eb838178
2024-04-02 20:07:43 +08:00
hiyouga
117b67ea30 add moe aux loss control #3085
Former-commit-id: c9187ebc944e2de454ace3304b7d28eabb1b1a81
2024-04-02 14:26:31 +08:00
hiyouga
03e20bb5c6 fix #3022
Former-commit-id: dac2f617bda9470ac8d85c7e9def09cc04970506
2024-04-02 13:58:39 +08:00
hiyouga
0c4a1381a4 Update SECURITY.md
Former-commit-id: e22217c75421a89fd7e2ada62ce0e08245dd05e7
2024-04-01 23:30:03 +08:00
hiyouga
9e14501edb set dev version
Former-commit-id: 922ecae89210e5b8d62d78774f123a6d75c525ba
2024-04-01 23:24:08 +08:00
hiyouga
1dc963caa6 fix #3083
Former-commit-id: ff9a3f73961a362d0ddc22079f80a85465fffda8
2024-04-01 22:53:52 +08:00
hiyouga
85726c91ce add qwen1.5 moe
Former-commit-id: 3ea94f0d12cec25ac694a2c4ae8971c356990b61
2024-04-01 21:49:40 +08:00
hiyouga
40211db275 fix #3077
Former-commit-id: d0340391e8075cff0d84b3ef879c2101b66ca1dc
2024-04-01 21:35:18 +08:00
hiyouga
e7f13098c6 support infer 4bit model on GPUs #3023
Former-commit-id: 950a9dab9055839990656b2b40956792b253573d
2024-04-01 17:34:04 +08:00
hiyouga
61eb3a3d46 update webui
Former-commit-id: e96d260917a35ad2068f7b28b4f0b334b808ccc2
2024-04-01 16:23:28 +08:00
hiyouga
be0a807e8c fix ORPO loss
Former-commit-id: 5544ddde9087f00f9e20b78d0079f20c2f5d1604
2024-04-01 14:42:41 +08:00
hiyouga
52d402e2a9 fix IPO and ORPO loss
Former-commit-id: fc27955732aedbb12003faf19b760e2768b228f2
2024-04-01 14:37:53 +08:00
hiyouga
c5a46f9113 fix plots
Former-commit-id: 81355671296b84d438967463bb2a92934ff31aae
2024-03-31 19:43:48 +08:00
hiyouga
00e17a377c use log1p in orpo loss
https://github.com/huggingface/trl/pull/1491

Former-commit-id: 3b15d495264b00a4f8716bafea334778874963d7
2024-03-31 19:27:08 +08:00
hiyouga
9abd83adb1 update readme
Former-commit-id: 297b01f16ac78cde15a5d85a9a5b82ea20bfaf23
2024-03-31 18:46:34 +08:00
hoshi-hiyouga
f0d2afcf90 Merge pull request #3066 from hiyouga/orpo
support ORPO

Former-commit-id: fd4d3d29a9fae289f3bd0c4ce00656e4ccbec2e1
2024-03-31 18:42:48 +08:00
hiyouga
1aba442bcd support orpo in webui
Former-commit-id: dd5cc78d4fb18dd0a2e9d57f0f046cfe9f0dc2c9
2024-03-31 18:34:59 +08:00
hiyouga
d764cd8736 support ORPO
Former-commit-id: f44a4c27e2461cdaa1b16865f597a31033c0e6d9
2024-03-31 18:29:50 +08:00
hiyouga
526111a303 tiny fix
Former-commit-id: ba4a9b3c01e2f7467fbc5be268f47c0d003caa65
2024-03-31 00:10:29 +08:00
hoshi-hiyouga
b8364046df Merge pull request #3057 from marko1616/bugfix/lora-model-merge
Fix Llama model save for full param train

Former-commit-id: 18303e34d07dbf4c2dd9bac03243a3ed38582515
2024-03-31 00:07:20 +08:00
marko1616
1f617c6e08 fix blank line contains whitespace
Former-commit-id: 7bc3bcc64353d5a1d4870c6a9509b64cff710492
2024-03-30 23:46:55 +08:00
marko1616
a6858a36c0 Fix Llama model save for full param train
Former-commit-id: ca17b5db4f97c3ec9fe2004877f150e8f51ab4b5
2024-03-30 23:45:04 +08:00
hiyouga
6198121923 support save args in webui #2807 #3046
some ideas are borrowed from @marko1616


Former-commit-id: b5a062aa2d4a37670007e8b3dae5b6f5b7ffb15c
2024-03-30 23:09:12 +08:00
hiyouga
b0efebf853 upgrade gradio to 4.21.0
Former-commit-id: 63eecbeb967d849e1d03d8d03fb6421c0ee89257
2024-03-30 20:37:08 +08:00
hiyouga
fbd0584391 release v0.6.1
Former-commit-id: a59d823f554505b2e649e6e111b9dee8306d3ad8
2024-03-29 11:36:08 +08:00
hiyouga
50224b09cc update readme
Former-commit-id: 312d4f90784800dc8db4eaa7d908e6761115bc51
2024-03-28 22:02:32 +08:00
hiyouga
32dcc5a491 add project
Former-commit-id: 0418e9fecb2337b5d1b72e8358adb8aa10803c4b
2024-03-28 20:24:27 +08:00
hiyouga
9408366a36 fix #2982
Former-commit-id: e5e6a0c50c7a1c0052ed6b459450b9735ff2c9a1
2024-03-28 20:22:31 +08:00
hiyouga
f0e564beaa update readme
Former-commit-id: 6b634b5c2dbad827e8cc9850b8d7697c2056532a
2024-03-28 18:35:11 +08:00
hiyouga
14b75a0b93 fix #3010
Former-commit-id: a5e823ae75556eaa3b52ce7a887a6e7838a1eb5f
2024-03-28 18:31:17 +08:00
hiyouga
59e6ebf039 update trainers
Former-commit-id: d0dd6eefed0b86895ed00a7cafb331e5193db645
2024-03-28 18:16:27 +08:00
hoshi-hiyouga
dc540dfaa8 fix ds optimizer
Former-commit-id: 2675127070a1e7584e71039a11c1ebac54ddd1db
2024-03-26 23:39:56 +08:00
hiyouga
587e65e442 fix #2981
Former-commit-id: ede2a913856e52c0a96155705116528d3af15998
2024-03-26 17:53:04 +08:00
hiyouga
a916688723 fix bug
Former-commit-id: f513e1415cc3fe87f600318fba855d1286b6d007
2024-03-26 17:30:12 +08:00
hiyouga
3336422760 fix #2961
Former-commit-id: 616917bb3be7f71073b56ad8c7bc4e164b08b9b5
2024-03-26 17:26:14 +08:00
128 changed files with 3929 additions and 2355 deletions

.github/SECURITY.md (vendored, 2 changes)

@@ -1,6 +1,6 @@
# Reporting Security Issues
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/electron/electron/security/advisories/new) tab.
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/hiyouga/LLaMA-Factory/security/advisories/new) tab.
We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.

README.md (527 changes)

@@ -5,7 +5,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-26-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -43,17 +43,17 @@ Choose your path:
## Features
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
## Benchmark
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
![benchmark](assets/benchmark.svg)
@@ -62,24 +62,36 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
</details>
## Changelog
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See `examples/lora_single_gpu/sft_mllm.sh` for usage.
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
<details><summary>Full Changelog</summary>
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.
@@ -100,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
@@ -125,32 +137,37 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
> [!NOTE]
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
> Remember to use the **SAME** template in training and inference.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
@@ -165,9 +182,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use `--quantization_bit 4` argument to enable QLoRA.
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## Provided Datasets
@@ -223,6 +238,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -242,12 +258,11 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details>
Please refer to [data/README.md](data/README.md) for details.
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
```bash
@@ -261,8 +276,8 @@ huggingface-cli login
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 |
| transformers | 4.37.2 | 4.39.1 |
| datasets | 2.14.3 | 2.17.1 |
| transformers | 4.37.2 | 4.39.3 |
| datasets | 2.14.3 | 2.18.0 |
| accelerate | 0.27.2 | 0.28.0 |
| peft | 0.9.0 | 0.10.0 |
| trl | 0.8.1 | 0.8.1 |
@@ -278,36 +293,39 @@ huggingface-cli login
\* *estimated*
| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
## Getting Started
### Data Preparation (optional)
### Data Preparation
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
> Please update `data/dataset_info.json` to use your custom dataset.
### Dependence Installation (optional)
### Dependence Installation
```bash
git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
pip install -r requirements.txt
pip install -e .[metrics]
```
Extra dependencies available: deepspeed, metrics, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
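For example, several extras can be combined in a single editable install; a minimal sketch, assuming the extra names listed above are the ones you need:
```bash
# Editable install with a few optional extras picked from the list above.
# Quote the requirement so the brackets are not expanded by the shell.
pip install -e ".[metrics,deepspeed,vllm]"
```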
<details><summary>For Windows users</summary>
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
```bash
@@ -316,378 +334,92 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
### Use ModelScope Hub (optional)
</details>
If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```
Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path modelscope/Llama-2-7b-ms \
... # arguments (same as below)
```
LLaMA Board also supports using the models and datasets on the ModelScope Hub.
```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
### Train on a single GPU
### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
> [!IMPORTANT]
> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
#### LLaMA Board GUI
#### Use local environment
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
python src/train_web.py # or python -m llmtuner.webui.interface
```
#### Pre-Training
<details><summary>For Alibaba Cloud users</summary>
If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset wiki_demo \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### Supervised Fine-Tuning
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--top_k 0 \
--top_p 0.9 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
#### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
### Distributed Training
#### Use Huggingface Accelerate
```bash
accelerate launch --config_file config.yaml src/train_bash.py \
--ddp_timeout 180000000 \
... # arguments (same as above)
```
<details><summary>Example config.yaml for LoRA training</summary>
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
</details>
> [!TIP]
> We recommend using Accelerate for LoRA tuning.
#### Use DeepSpeed
```bash
deepspeed --num_gpus 8 src/train_bash.py \
--deepspeed ds_config.json \
--ddp_timeout 180000000 \
... # arguments (same as above)
```
<details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
}
}
```
</details>
> [!TIP]
> Refer to [examples](examples) for more training scripts.
### Merge LoRA weights and export model
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora \
--export_dir path_to_export \
--export_size 2 \
--export_legacy_format False
```
> [!WARNING]
> Merging LoRA weights into a quantized model is not supported.
> [!TIP]
> Use `--model_name_or_path path_to_export` solely to use the exported model.
>
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model with AutoGPTQ after merging the LoRA weights.
### Inference with OpenAI-style API
```bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
> [!TIP]
> Visit `http://localhost:8000/docs` for API documentation.
### Inference with command line
```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### Inference with web browser
```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### Evaluation
```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template vanilla \
--finetuning_type lora \
--task mmlu \
--split test \
--lang en \
--n_shot 5 \
--batch_size 4
```
### Predict
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 1 \
--max_samples 100 \
--predict_with_generate \
--fp16
```
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 predict.
> [!TIP]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
### Dockerize Training
#### Get ready
A working containerized environment is required, such as Docker or Docker Compose.
#### Docker support
#### Use Docker
```bash
docker build -f ./Dockerfile -t llama-factory:latest .
docker run --gpus=all -v ./hf_cache:/root/.cache/huggingface/ -v ./data:/app/data -v ./output:/app/output -p 7860:7860 --shm-size 16G --name llama_factory -d llama-factory:latest
docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \
-v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \
--shm-size 16G \
--name llama_factory \
-d llama-factory:latest
```
#### Docker Compose support
#### Use Docker Compose
```bash
docker compose -f ./docker-compose.yml up -d
```
> [!TIP]
> Details about volume:
> * hf_cache: Utilize Huggingface cache on the host machine. Reassignable if a cache already exists in a different directory.
> * data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
> * output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
<details><summary>Details about volume</summary>
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
</details>
### Train with Command Line Interface
See [examples/README.md](examples/README.md) for usage.
Use `python src/train_bash.py -h` to display arguments description.
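As a quick reference, here is a condensed sketch of a single-GPU LoRA SFT run, reusing the flags from the full supervised fine-tuning example shown earlier in this diff (all `path_to_*` values are placeholders):
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16
```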
### Deploy with OpenAI-style API and vLLM
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \
--infer_backend vllm \
--vllm_enforce_eager
```
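Once the server is running, it can be queried as an OpenAI-compatible endpoint. A hedged sketch, assuming the standard `/v1/chat/completions` route on the `API_PORT` set above; the request body follows the OpenAI chat convention rather than being taken verbatim from this repository:
```bash
# Example request against the local OpenAI-style API started above.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}]
      }'
```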
### Download from ModelScope Hub
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```
Train the model by specifying a ModelScope Hub model ID as the value of `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
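For example, the following sketch starts a training run against a ModelScope model ID (the remaining arguments are identical to the regular training examples):
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path modelscope/Llama-2-7b-ms \
    ... # arguments (same as usual)
```
LLaMA Board honors the same variable, e.g. `CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py`.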
## Projects using LLaMA Factory
If you have a project that should be incorporated, please contact via email or create a pull request.
<details><summary>Click to show</summary>
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
@@ -709,20 +441,27 @@ docker compose -f ./docker-compose.yml up -d
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
</details>
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation
@@ -740,7 +479,7 @@ If this work is helpful, please kindly cite as:
## Acknowledgement
This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
## Star History

View File

@@ -5,13 +5,13 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-34-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
👋 加入我们的[微信群](assets/wechat.jpg)。
@@ -23,7 +23,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
选择你的打开方式:
- **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **本地机器**:请见[如何使用](#如何使用)
## 目录
@@ -43,17 +43,17 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
## 性能指标
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比LLaMA Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
![benchmark](assets/benchmark.svg)
@@ -62,24 +62,36 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4截断长度=1024
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4截断长度=1024
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1截断长度=1024
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA Factory 的 LoRA 微调中采用 `lora_rank=32`
</details>
## 更新日志
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 `examples/lora_single_gpu/sft_mllm.sh`
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
<details><summary>展开日志</summary>
[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/extras/fsdp_qlora`
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA请先合并权重。
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `examples/extras/llama_pro`
@@ -100,7 +112,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn fa2` 参数以启用 FlashAttention-2。
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
@@ -125,32 +137,37 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
> [!NOTE]
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以得到更好的效果
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat模型请务必使用**对应的模板**。
>
> 请务必在训练和推理时使用**完全一致**的模板。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
@@ -165,9 +182,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> 请使用 `--quantization_bit 4` 参数来启用 QLoRA 训练。
## 数据集
@@ -223,6 +238,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -242,12 +258,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details>
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
```bash
@@ -261,8 +276,8 @@ huggingface-cli login
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 |
| transformers | 4.37.2 | 4.39.3 |
| datasets | 2.14.3 | 2.18.0 |
| accelerate | 0.27.2 | 0.28.0 |
| peft | 0.9.0 | 0.10.0 |
| trl | 0.8.1 | 0.8.1 |
@@ -278,36 +293,39 @@ huggingface-cli login
\* *估算值*
| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
## 如何使用
### 数据准备
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集
> [!NOTE]
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
### 安装依赖
```bash
git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
pip install -e .[metrics]
```
可选的额外依赖项deepspeed、metrics、galore、badam、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
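例如,可以按需组合安装上述可选依赖(以下命令仅为示例):
```bash
pip install -e .[metrics,vllm,bitsandbytes]
```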
<details><summary>Windows 用户指南</summary>
如果要在 Windows 平台上开启量化 LoRAQLoRA需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
```bash
@@ -316,7 +334,77 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
如果要在 Windows 平台上开启 FlashAttention-2需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
</details>
### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
> [!IMPORTANT]
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。
#### 使用本地环境
```bash
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
python src/train_web.py # 或 python -m llmtuner.webui.interface
```
<details><summary>阿里云用户指南</summary>
如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
```bash
export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
</details>
#### 使用 Docker
```bash
docker build -f ./Dockerfile -t llama-factory:latest .
docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \
-v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \
--shm-size 16G \
--name llama_factory \
-d llama-factory:latest
```
#### 使用 Docker Compose
```bash
docker compose -f ./docker-compose.yml up -d
```
<details><summary>数据卷详情</summary>
- hf_cache使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
- data宿主机中存放数据集的文件夹路径。
- output将导出目录设置为该路径后即可在宿主机中访问导出后的模型。
</details>
### 利用命令行接口训练
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
您可以执行 `python src/train_bash.py -h` 来查看参数文档。
### 利用 vLLM 部署 OpenAI API
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \
--infer_backend vllm \
--vllm_enforce_eager
```
### 从魔搭社区下载
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
@@ -324,343 +412,14 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
```
接着即可通过指定模型名称来训练对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path modelscope/Llama-2-7b-ms \
... # 参数同下
```
LLaMA Board 同样支持魔搭社区的模型和数据集下载。
```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
### 单 GPU 训练
> [!IMPORTANT]
> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
#### LLaMA Board GUI
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
#### 预训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset wiki_demo \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### 指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path path_to_llama_model \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
#### 奖励模型训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
#### PPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--top_k 0 \
--top_p 0.9 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` 来进行微调模型的推理。
> [!WARNING]
> 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。
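例如,下面的示例在推理时同时加载 SFT 与 PPO 两个适配器(参数与上文的推理命令一致,仅作参考):
```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint \
    --template default \
    --finetuning_type lora
```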
#### DPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_sft_checkpoint \
--create_new_adapter \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
> [!TIP]
> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` 来进行微调模型的推理。
### 多 GPU 分布式训练
#### 使用 Huggingface Accelerate
```bash
accelerate launch --config_file config.yaml src/train_bash.py \
--ddp_timeout 180000000 \
... # 参数同上
```
<details><summary>使用 Accelerate 进行 LoRA 训练的 config.yaml 示例</summary>
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
> [!TIP]
> 我们推荐使用 Accelerate 进行 LoRA 训练。
#### 使用 DeepSpeed
```bash
deepspeed --num_gpus 8 src/train_bash.py \
--deepspeed ds_config.json \
--ddp_timeout 180000000 \
... # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数训练的 ds_config.json 示例</summary>
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
}
}
```
</details>
> [!TIP]
> 更多训练脚本请查看 [examples](examples)。
### 合并 LoRA 权重并导出模型
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora \
--export_dir path_to_export \
--export_size 2 \
--export_legacy_format False
```
> [!WARNING]
> 尚不支持量化模型的 LoRA 权重合并及导出。
> [!TIP]
> 仅使用 `--model_name_or_path path_to_export` 来加载导出后的模型。
>
> 合并 LoRA 权重之后可再次使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 基于 AutoGPTQ 量化模型。
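例如,下面的示例在合并导出之后,再次调用导出脚本对合并后的模型进行 4 比特 GPTQ 量化(参数组合仅为示意,请以实际版本的参数说明为准):
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
    --model_name_or_path path_to_export \
    --template default \
    --export_dir path_to_quantized_export \
    --export_quantization_bit 4 \
    --export_quantization_dataset data/c4_demo.json \
    --export_size 2
```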
### 使用 OpenAI 风格 API 推理
```bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
> [!TIP]
> 关于 API 文档请见 `http://localhost:8000/docs`。
### 使用命令行推理
```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### 使用浏览器推理
```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
--finetuning_type lora
```
### 模型评估
```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template vanilla \
--finetuning_type lora \
--task ceval \
--split validation \
--lang zh \
--n_shot 5 \
--batch_size 4
```
### 模型预测
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 1 \
--max_samples 100 \
--predict_with_generate \
--fp16
```
> [!WARNING]
> 如果使用 fp16 精度进行 LLaMA-2 模型的预测,请使用 `--per_device_eval_batch_size=1`。
> [!TIP]
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
`--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`
## 使用了 LLaMA Factory 的项目
如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
<details><summary>点击显示</summary>
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
@@ -682,20 +441,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**MBTI性格大模型项目根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
</details>
## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用
@@ -713,7 +479,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
## 致谢
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[TRL](https://github.com/huggingface/trl)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
## Star History

View File

@@ -18,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
"history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)"
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
@@ -34,6 +35,8 @@ If you are using a custom dataset, please provide your dataset definition in the
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
----
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
```json
@@ -84,6 +87,10 @@ For the preference datasets, the `response` column should be a string list whose
}
```
Remember to set `"ranking": true` for the preference datasets.
----
The dataset in sharegpt format should follow the below format:
```json

View File

@@ -18,7 +18,8 @@
"history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None"
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
@@ -34,6 +35,8 @@
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
----
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
```json
@@ -84,6 +87,10 @@
}
```
添加偏好数据集需要额外指定 `"ranking": true`
----
而 sharegpt 格式的数据集按照以下方式组织:
```json

View File

@@ -1 +1 @@
34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
a97cf9475291591843976554878568e046d8a46d

View File

@@ -1,5 +1,6 @@
import os
import json
import os
import datasets
@@ -22,31 +23,19 @@ _URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0
class BelleMultiturn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self):
features = datasets.Features({
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
features = datasets.Features(
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": file_path
}
)
]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str):
with open(filepath, "r", encoding="utf-8") as f:
@@ -58,7 +47,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
assist_idx = prompt.rfind("Assistant:")
human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip()
query = prompt[human_idx + 6 : assist_idx].strip()
prompt = prompt[:human_idx].strip()
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
@@ -67,8 +56,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
assist_idx = prompt.rfind("Assistant:")
human_idx = prompt.rfind("Human:")
if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip()
old_query = prompt[human_idx + 6 : assist_idx].strip()
old_resp = prompt[assist_idx + 10 :].strip()
conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else:

View File

@@ -1,7 +1,8 @@
import json
import datasets
from typing import Any, Dict, Generator, List, Tuple
import datasets
_DESCRIPTION = "An example of dataset."
_CITATION = ""
@@ -11,34 +12,24 @@ _URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": file_path
}
)
]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))

View File

@@ -1,8 +1,10 @@
import os
import json
import datasets
import os
from typing import List
import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
@@ -14,50 +16,37 @@ _URLS = {
_URL + "harmless-base/train.jsonl.gz",
_URL + "helpful-base/train.jsonl.gz",
_URL + "helpful-online/train.jsonl.gz",
_URL + "helpful-rejection-sampled/train.jsonl.gz"
_URL + "helpful-rejection-sampled/train.jsonl.gz",
],
"test": [
_URL + "harmless-base/test.jsonl.gz",
_URL + "helpful-base/test.jsonl.gz",
_URL + "helpful-online/test.jsonl.gz",
_URL + "helpful-rejection-sampled/test.jsonl.gz"
]
_URL + "helpful-rejection-sampled/test.jsonl.gz",
],
}
class HhRlhfEn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"output": datasets.Sequence(datasets.Value("string")),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download_and_extract(_URLS)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_path["train"]
}
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepaths": file_path["test"]
}
)
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
]
def _generate_examples(self, filepaths: List[str]):
@@ -70,12 +59,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
rejected = data["rejected"]
assist_idx = rejected.rfind("\n\nAssistant: ")
r_reject = rejected[assist_idx+13:].strip()
r_reject = rejected[assist_idx + 13 :].strip()
assist_idx = chosen.rfind("\n\nAssistant: ")
r_accept = chosen[assist_idx+13:].strip()
r_accept = chosen[assist_idx + 13 :].strip()
human_idx = chosen.rfind("\n\nHuman: ")
query = chosen[human_idx+9:assist_idx].strip()
query = chosen[human_idx + 9 : assist_idx].strip()
prompt = chosen[:human_idx]
history = []
@@ -83,16 +72,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
assist_idx = prompt.rfind("\n\nAssistant: ")
human_idx = prompt.rfind("\n\nHuman: ")
if human_idx != -1:
old_query = prompt[human_idx+9:assist_idx].strip()
old_resp = prompt[assist_idx+13:].strip()
old_query = prompt[human_idx + 9 : assist_idx].strip()
old_resp = prompt[assist_idx + 13 :].strip()
history.insert(0, (old_query, old_resp))
else:
break
prompt = prompt[:human_idx]
yield key, {
"instruction": query,
"output": [r_accept, r_reject],
"history": history
}
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
key += 1

View File

@@ -1,8 +1,10 @@
import os
import json
import datasets
import os
from typing import List
import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
@@ -24,31 +26,19 @@ _BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jso
class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self):
features = datasets.Features({
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
features = datasets.Features(
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_paths
}
)
]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths:
@@ -56,7 +46,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
for row in f:
try:
data = json.loads(row)
except:
except Exception:
continue
key: int = data["id"]
content: List[str] = data["data"]
@@ -64,8 +54,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
content.pop(-1)
if len(content) < 2:
continue
conversations = [{
"from": "human" if i % 2 == 0 else "gpt",
"value": content[i]
} for i in range(len(content))]
conversations = [
{"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
]
yield key, {"conversations": conversations}

View File

@@ -10,6 +10,8 @@ services:
- ./hf_cache:/root/.cache/huggingface/
- ./data:/app/data
- ./output:/app/output
environment:
- CUDA_VISIBLE_DEVICES=0
ports:
- "7860:7860"
ipc: host

examples/README.md Normal file
View File

@@ -0,0 +1,50 @@
We provide diverse examples of fine-tuning LLMs.
```
examples/
├── lora_single_gpu/
│ ├── pretrain.sh: Do continuous pre-training using LoRA
│ ├── sft.sh: Do supervised fine-tuning using LoRA
│ ├── reward.sh: Do reward modeling using LoRA
│ ├── ppo.sh: Do PPO training using LoRA
│ ├── dpo.sh: Do DPO training using LoRA
│ ├── orpo.sh: Do ORPO training using LoRA
│ ├── sft_mllm.sh: Do supervised fine-tuning on multimodal data using LoRA
│ ├── prepare.sh: Save tokenized dataset
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
├── qlora_single_gpu/
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
│ ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
│ └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
├── lora_multi_gpu/
│ ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
│ ├── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
│ └── ds_zero3.sh: Fine-tune model with DeepSpeed ZeRO-3 using LoRA (weight sharding)
├── full_multi_gpu/
│ ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
│ ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
│ └── predict.sh: Do parallel batch predict and compute BLEU and ROUGE scores after full tuning
├── merge_lora/
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
│ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
├── inference/
│ ├── cli_demo.sh: Chat with fine-tuned model in the CLI with LoRA adapters
│ ├── api_demo.sh: Chat with fine-tuned model in an OpenAI-style API with LoRA adapters
│ ├── web_demo.sh: Chat with fine-tuned model in the Web browser with LoRA adapters
│ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
└── extras/
├── galore/
│ └── sft.sh: Fine-tune model with GaLore
├── badam/
│ └── sft.sh: Fine-tune model with BAdam
├── loraplus/
│ └── sft.sh: Fine-tune model using LoRA+
├── mod/
│ └── sft.sh: Fine-tune model using Mixture-of-Depths
├── llama_pro/
│ ├── expand.sh: Expand layers in the model
│ └── sft.sh: Fine-tune the expanded model
└── fsdp_qlora/
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
```
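Each script above is self-contained; a typical workflow is to enter the corresponding directory, adjust the paths inside the script if needed, and run it directly, for example:
```bash
cd examples/lora_single_gpu
bash sft.sh
```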

examples/README_zh.md Normal file
View File

@@ -0,0 +1,50 @@
我们提供了多样化的大模型微调示例脚本。
```
examples/
├── lora_single_gpu/
│ ├── pretrain.sh: 基于 LoRA 进行增量预训练
│ ├── sft.sh: 基于 LoRA 进行指令监督微调
│ ├── reward.sh: 基于 LoRA 进行奖励模型训练
│ ├── ppo.sh: 基于 LoRA 进行 PPO 训练
│ ├── dpo.sh: 基于 LoRA 进行 DPO 训练
│ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
│ ├── sft_mllm.sh: 基于 LoRA 进行多模态指令监督微调
│ ├── prepare.sh: 保存预处理后的数据集
│ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
├── qlora_single_gpu/
│ ├── bitsandbytes.sh: 基于 QLoRA 微调 4/8 比特 BNB 模型
│ ├── gptq.sh: 基于 QLoRA 微调 4/8 比特 GPTQ 模型
│ ├── awq.sh: 基于 QLoRA 微调 4 比特 AWQ 模型
│ └── aqlm.sh: 基于 QLoRA 微调 2 比特 AQLM 模型
├── lora_multi_gpu/
│ ├── single_node.sh: 使用 Accelerate 进行单节点 LoRA 训练
│ ├── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
│ └── ds_zero3.sh: 使用 DeepSpeed ZeRO-3 进行 LoRA 训练(拆分权重)
├── full_multi_gpu/
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点全量训练
│ ├── multi_node.sh: 使用 DeepSpeed 进行多节点全量训练
│ └── predict.sh: 基于全量训练进行多卡批量预测并计算 BLEU 和 ROUGE 分数
├── merge_lora/
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
│ └── quantize.sh: 使用 AutoGPTQ 量化微调后的模型
├── inference/
│ ├── cli_demo.sh: 启动 LoRA 模型的命令行推理接口
│ ├── api_demo.sh: 启动 LoRA 模型的 OpenAI 风格 API
│ ├── web_demo.sh: 启动 LoRA 模型的浏览器推理接口
│ └── evaluate.sh: 在 MMLU/CMMLU/C-Eval 数据集上评测 LoRA 模型
└── extras/
├── galore/
│ └── sft.sh: 使用 GaLore 训练模型
├── badam/
│ └── sft.sh: 使用 BAdam 训练模型
├── loraplus/
│ └── sft.sh: 使用 LoRA+ 训练模型
├── mod/
│ └── sft.sh: 使用深度混合训练模型
├── llama_pro/
│ ├── expand.sh: 扩展模型中的层
│ └── sft.sh: 训练扩展后的模型
└── fsdp_qlora/
└── sft.sh: 使用 FSDP+QLoRA 微调量化模型
```

View File

@@ -9,14 +9,14 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

View File

@@ -8,8 +8,8 @@ main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

View File

@@ -6,8 +6,8 @@ gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
num_machines: 1 # the number of nodes
num_processes: 4 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

View File

@@ -8,8 +8,8 @@ main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

View File

@@ -8,14 +8,18 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--use_badam \
--badam_switch_mode descending \
--badam_switch_block_every 50 \
--badam_verbose 2 \
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
@@ -28,4 +32,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16
--pure_bf16

View File

@@ -1,25 +1,29 @@
#!/bin/bash
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
pip install "transformers>=4.39.1"
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file ../../accelerate/fsdp_config.yaml \
../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--use_galore \
--galore_layerwise \
--galore_target mlp,self_attn \
--galore_rank 128 \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../../saves/LLaMA2-70B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
@@ -31,5 +35,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 180000000 \
--quantization_bit 4 \
--plot_loss \
--fp16

View File

@@ -8,11 +8,11 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--optim adamw_8bit \
--use_galore \
--galore_layerwise \
--galore_target mlp,self_attn \
--galore_rank 128 \
--galore_scale 2.0 \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \

View File

@@ -9,6 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--loraplus_lr_ratio 16.0 \
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
--overwrite_cache \
--overwrite_output_dir \
@@ -29,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16 \
--loraplus_lr_ratio 16.0
--fp16

View File

@@ -8,15 +8,16 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--optim adamw_8bit \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--mixture_of_depths convert \
--output_dir ../../../saves/LLaMA2-7B/mod/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--gradient_accumulation_steps 8 \
--optim paged_adamw_8bit \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \

View File

@@ -1,5 +0,0 @@
```bash
pip install git+https://github.com/huggingface/transformers.git
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
```

View File

@@ -33,6 +33,6 @@ python -m torch.distributed.run \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 1800000 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,20 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type full \
--output_dir ../../saves/LLaMA2-7B/full/predict \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_eval_batch_size 1 \
--max_samples 20 \
--predict_with_generate

View File

@@ -27,6 +27,6 @@ deepspeed --num_gpus 4 ../../src/train_bash.py \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 1800000 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora
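
The API demo above listens on API_PORT. A hypothetical client call, assuming the server exposes an OpenAI-compatible /v1/chat/completions endpoint on port 8000 as configured above (the model name is a placeholder):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "llama2-7b-sft",  # placeholder; a single-model server typically ignores this field
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])
```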

View File

@@ -0,0 +1,7 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora

View File

@@ -0,0 +1,12 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template fewshot \
--finetuning_type lora \
--task mmlu \
--split test \
--lang en \
--n_shot 5 \
--batch_size 4

View File

@@ -0,0 +1,8 @@
#!/bin/bash
# add `--visual_inputs True` to load MLLM
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora

View File

@@ -1,33 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file ../accelerate/fsdp_config.yaml \
../../src/train_bash.py \
deepspeed --num_gpus 4 ../../src/train_bash.py \
--deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-70B/lora/sft \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--quantization_bit 4 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# also launch it on slave machine using slave_config.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/master_config.yaml \
@@ -30,6 +31,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 1800000 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \
--stage sft \
@@ -30,6 +30,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 1800000 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -1,8 +0,0 @@
Usage:
- `pretrain.sh`: do pre-training (optional)
- `sft.sh`: do supervised fine-tuning
- `reward.sh`: do reward modeling (must be run after sft.sh)
- `ppo.sh`: do PPO training (must be run after sft.sh and reward.sh)
- `dpo.sh`: do DPO training (must be run after sft.sh)
- `predict.sh`: do prediction (must be run after sft.sh and dpo.sh)

View File

@@ -6,7 +6,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
--dataset comparison_gpt4_en \
--dataset orca_rlhf \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \

View File

@@ -0,0 +1,32 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage orpo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset orca_rlhf \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/orpo \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--max_samples 1000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -0,0 +1,18 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--max_samples 3000 \
--tokenized_path ../../saves/datasets/sft
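
The script above only tokenizes the dataset (note the empty CUDA_VISIBLE_DEVICES, so no GPU is needed) and saves it to --tokenized_path; later training runs can pass the same --tokenized_path to skip preprocessing. A quick sanity check on the saved dataset, assuming it was written by the script above:

```python
from datasets import load_from_disk

dataset = load_from_disk("../../saves/datasets/sft")  # path from --tokenized_path above
print(len(dataset))                   # number of tokenized samples
print(dataset[0]["input_ids"][:16])   # column present after SFT preprocessing
```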

View File

@@ -6,7 +6,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--create_new_adapter \
--dataset comparison_gpt4_en \
--dataset orca_rlhf \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--visual_inputs \
--dataset mllm_demo \
--dataset_dir ../../data \
--template vicuna \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft_mllm \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 100.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -1,4 +0,0 @@
Usage:
- `merge.sh`: merge the LoRA weights
- `quantize.sh`: quantize the model with AutoGPTQ (must be run after merge.sh; optional)

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# DO NOT use a quantized model or the quantization_bit argument when merging LoRA weights
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
@@ -7,4 +8,5 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--finetuning_type lora \
--export_dir ../../models/llama2-7b-sft \
--export_size 2 \
--export_device cpu \
--export_legacy_format False

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# NEED TO run `merge.sh` before using this script
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path ../../models/llama2-7b-sft \

View File

@@ -2,9 +2,9 @@ torch>=1.13.1
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.27.2
peft>=0.9.0
peft>=0.10.0
trl>=0.8.1
gradio>=3.38.0,<4.0.0
gradio>=4.0.0
scipy
einops
sentencepiece
@@ -15,4 +15,4 @@ fastapi
sse-starlette
matplotlib
fire
galore-torch
packaging

View File

@@ -15,7 +15,7 @@ from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer
from llmtuner.model import load_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
@@ -32,7 +32,7 @@ def calculate_lr(
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage=stage,
model_name_or_path=model_name_or_path,
@@ -44,8 +44,9 @@ def calculate_lr(
overwrite_cache=True,
)
)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
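
The hunks above replace load_model_and_tokenizer with load_tokenizer, which now returns a dict so that callers can forward both the tokenizer and the (possibly absent) multimodal processor via **tokenizer_module. A self-contained sketch of that contract, with AutoTokenizer/AutoProcessor standing in for the repo's internal loader (this is illustrative, not the actual implementation):

```python
from typing import Optional

from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin


def load_tokenizer_sketch(model_name: str, visual_inputs: bool = False) -> dict:
    # Returns a dict so callers can simply unpack it with **tokenizer_module.
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_name)
    processor: Optional[ProcessorMixin] = (
        AutoProcessor.from_pretrained(model_name) if visual_inputs else None
    )
    return {"tokenizer": tokenizer, "processor": processor}


tokenizer_module = load_tokenizer_sketch("llava-hf/llava-1.5-7b-hf", visual_inputs=True)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]  # None for text-only models
```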

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer
from llmtuner.model import load_tokenizer
def length_cdf(
@@ -20,7 +20,7 @@ def length_cdf(
template: Optional[str] = "default",
interval: Optional[int] = 1000,
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage="sft",
model_name_or_path=model_name_or_path,
@@ -32,8 +32,8 @@ def length_cdf(
overwrite_cache=True,
)
)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
tokenizer_module = load_tokenizer(model_args)
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):

View File

@@ -1,114 +0,0 @@
# coding=utf-8
# Converts the InternLM2 model in the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.

import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)

CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        internlm2_config_dict: Dict[str, Any] = json.load(f)

    internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            internlm2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
        if "output" in key:
            llama2_state_dict[key.replace("output", "lm_head")] = value
        elif "tok_embeddings" in key:
            llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
        elif "wqkv" in key:
            num_q_heads = internlm2_config_dict["num_attention_heads"]
            num_kv_heads = internlm2_config_dict["num_key_value_heads"]
            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
                q_size : q_size + kv_size, ...
            ]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
        elif "wo" in key:
            llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
        elif "attention_norm" in key:
            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
        elif "ffn_norm" in key:
            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
        elif "w1" in key:
            llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
        elif "w2" in key:
            llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
        elif "w3" in key:
            llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
        else:
            llama2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("bias", None)
    llama2_config_dict.pop("rope_scaling", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_internlm2(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError("Output dir already exists: {}".format(e))

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_internlm2)
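
The wqkv branch of the deleted converter above splits InternLM2's fused attention weight into separate q/k/v projections. A worked example of the slicing arithmetic, assuming InternLM2-7B-like settings (32 query heads, 8 key/value heads, head_dim 128; the numbers are illustrative):

```python
num_q_heads, num_kv_heads, head_dim = 32, 8, 128
wqkv_rows = (num_q_heads + 2 * num_kv_heads) * head_dim                  # 6144 rows in the fused weight
q_size = wqkv_rows // (num_q_heads + 2 * num_kv_heads) * num_q_heads     # 4096 rows -> q_proj
kv_size = wqkv_rows // (num_q_heads + 2 * num_kv_heads) * num_kv_heads   # 1024 rows -> k_proj and v_proj
print(q_size, kv_size, q_size + 2 * kv_size == wqkv_rows)                # 4096 1024 True
```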

View File

@@ -20,15 +20,17 @@ def get_requires():
extra_require = {
"deepspeed": ["deepspeed"],
"deepspeed": ["deepspeed>=0.10.0"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
"vllm": ["vllm>=0.3.3"],
"galore": ["galore-torch"],
"badam": ["badam"],
"vllm": ["vllm>=0.4.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"qwen": ["tiktoken", "transformers_stream_generator"],
"modelscope": ["modelscope"],
"quality": ["ruff"],
}

View File

@@ -2,8 +2,7 @@ from llmtuner import Evaluator
def main():
evaluator = Evaluator()
evaluator.eval()
Evaluator().eval()
if __name__ == "__main__":

View File

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.6.0"
__version__ = "0.7.0"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:

View File

@@ -1,6 +1,6 @@
import time
from enum import Enum, unique
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@@ -39,6 +39,17 @@ class Function(BaseModel):
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function"
@@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
class ChatMessage(BaseModel):
role: Role
content: str
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
@@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: list = []
tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None
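
The protocol changes above replace the untyped `tools: list = []` with structured pydantic models. A small standalone round-trip using the same field names (pydantic v2 syntax; the tool definition itself is a made-up example):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel
from typing_extensions import Literal


class FunctionDefinition(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]


class FunctionAvailable(BaseModel):
    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None


tool = FunctionAvailable(
    function=FunctionDefinition(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    )
)
print(tool.model_dump())  # use .dict() instead on pydantic v1
```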

View File

@@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass
class Response:
@@ -49,6 +47,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]: ...
@@ -58,6 +57,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...

View File

@@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from numpy.typing import NDArray
from .base_engine import BaseEngine, Response
@@ -36,9 +38,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
return task.result()
async def achat(
@@ -46,18 +49,20 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, **input_kwargs)
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, **input_kwargs)
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -70,9 +75,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
yield new_token
def get_scores(

View File

@@ -9,12 +9,14 @@ from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper
from ..data import Template
@@ -30,24 +32,32 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
@staticmethod
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
@@ -94,6 +104,11 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
if processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
return gen_kwargs, prompt_length
@staticmethod
@@ -101,15 +116,17 @@ class HuggingfaceEngine(BaseEngine):
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -134,15 +151,17 @@ class HuggingfaceEngine(BaseEngine):
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -198,6 +217,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -207,11 +227,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:
@@ -223,6 +245,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -232,11 +255,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:
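
The engine changes above thread an optional image through _process_args and feed the output of processor.image_processor to model.generate as pixel_values. A standalone sketch of just that image path, assuming the LLaVA-1.5 processor from the Hugging Face Hub:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.new("RGB", (336, 336), (255, 255, 255))  # stand-in for a real user image
pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: torch.Size([1, 3, 336, 336]) for LLaVA-1.5
```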

View File

@@ -1,19 +1,24 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
from ..model import load_tokenizer
from ..model import load_config, load_tokenizer
from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import MultiModalData
if TYPE_CHECKING:
import torch
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -25,32 +30,59 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
config = load_config(model_args) # may download model from ms hub
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
infer_dtype = str(infer_dtype).split(".")[-1]
self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path,
trust_remote_code=True,
max_model_len=model_args.vllm_maxlen,
tensor_parallel_size=get_device_count() or 1,
gpu_memory_utilization=model_args.vllm_gpu_util,
disable_log_stats=True,
disable_log_requests=True,
enforce_eager=model_args.vllm_enforce_eager,
)
self.model = AsyncLLMEngine.from_engine_args(engine_args)
self.tokenizer = load_tokenizer(model_args)
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
"dtype": infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if model_args.visual_inputs:
# TODO: auto derive from config
# https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
self.image_feature_size = 576
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
engine_args["image_input_shape"] = "1,3,336,336"
engine_args["image_feature_size"] = self.image_feature_size
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
self.lora_request = None
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
@@ -94,8 +126,21 @@ class VllmEngine(BaseEngine):
max_tokens=generating_args["max_new_tokens"],
skip_special_tokens=True,
)
if self.processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
result_generator = self.model.generate(
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
prompt=None,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=prompt_ids,
lora_request=self.lora_request,
multi_modal_data=multi_modal_data,
)
return result_generator
@@ -107,10 +152,11 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, **input_kwargs)
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for request_output in generator:
final_output = request_output
@@ -132,10 +178,11 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, **input_kwargs)
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
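
A note on the hard-coded image_feature_size = 576 in the vLLM branch above (the TODO says it should eventually be derived from the config): it corresponds to LLaVA-1.5's CLIP ViT-L/14 vision tower at 336x336 resolution, i.e. a 24x24 grid of patches. As a quick check:

```python
image_size, patch_size = 336, 14         # LLaVA-1.5 vision tower settings (assumed)
print((image_size // patch_size) ** 2)   # 576 visual tokens, matching image_feature_size
```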

View File

@@ -1,6 +1,15 @@
from .collator import PairwiseDataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
__all__ = [
"PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",
"templates",
"Role",
"split_dataset",
]

View File

@@ -1,3 +1,4 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -13,8 +14,23 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
outputs.append(os.path.join(data_args.dataset_dir, image))
else:
outputs.append(image)
return outputs
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@@ -44,12 +60,16 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -84,6 +104,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
@@ -96,12 +117,13 @@ def align_dataset(
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
tools: "...",
images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
@@ -114,6 +136,7 @@ def align_dataset(
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}
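
After the aligner changes above, every aligned example carries an images list alongside prompt/response/system/tools, with relative paths resolved against dataset_dir for script/file datasets. An illustrative aligned record (the path and texts are made up):

```python
aligned_example = {
    "prompt": [{"role": "user", "content": "What is in the picture?"}],
    "response": [{"role": "assistant", "content": "A cat sitting on a sofa."}],
    "system": "",
    "tools": "",
    "images": ["data/mllm_demo_data/1.jpg"],  # hypothetical path, already joined with dataset_dir
}
```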

View File

@@ -6,12 +6,15 @@ from transformers import DataCollatorForSeq2Seq
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
r"""
Masks out the input ids except for the responses.
"""
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
@@ -43,12 +46,6 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
)
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(
concatenated_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch = super().__call__(concatenated_features)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch

View File

@@ -1,11 +1,12 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
@@ -15,7 +16,7 @@ from .utils import checksum, merge_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
@@ -80,7 +81,9 @@ def load_single_dataset(
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
@@ -112,22 +115,23 @@ def load_single_dataset(
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
@@ -138,12 +142,15 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
@@ -156,10 +163,13 @@ def get_dataset(
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
exit(0)
if training_args.should_log:
try:

View File

@@ -28,6 +28,7 @@ class DatasetAttr:
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
system: Optional[str] = None
images: Optional[str] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
@@ -53,12 +54,19 @@ class DatasetAttr:
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
if data_args.dataset is not None:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
else:
dataset_names = []
if data_args.dataset_dir == "ONLINE":
dataset_info = None
else:
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
if len(dataset_names) != 0:
raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
@@ -69,6 +77,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None:
load_from = "ms_hub" if use_modelscope() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
@@ -92,7 +106,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system"]
column_names = ["system", "images"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:

View File

@@ -1,14 +1,22 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from ..extras.packages import is_pillow_available
from .utils import Role
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
@@ -18,21 +26,30 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0]
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
@@ -46,18 +63,25 @@ def preprocess_pretrain_dataset(
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
@@ -87,14 +111,16 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
@@ -139,17 +165,24 @@ def preprocess_packed_supervised_dataset(
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
@@ -170,22 +203,32 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
@@ -212,6 +255,8 @@ def preprocess_pairwise_dataset(
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
@@ -242,34 +287,54 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

View File

@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
template = templates["empty"] # placeholder
else:
template = templates.get(name, None)
if template is None:
@@ -385,7 +385,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
)
@@ -414,7 +415,7 @@ _register_template(
_register_template(
name="baichuan",
format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
@@ -441,6 +442,18 @@ _register_template(
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
default_system=(
"You are a helpful AI assistant built by MediaTek Research. "
"The user you are helping speaks Traditional Chinese and comes from Taiwan."
),
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -490,6 +503,7 @@ _register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
@@ -500,6 +514,7 @@ _register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
@@ -514,6 +529,21 @@ _register_template(
)
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -522,6 +552,32 @@ _register_template(
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)\nThis is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -554,6 +610,13 @@ _register_template(
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -562,10 +625,20 @@ _register_template(
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
efficient_eos=True,
force_system=True,
@@ -623,9 +696,36 @@ _register_template(
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"],
replace_eos=True,
)
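For reference, a minimal sketch of the single-turn prompt this llama3 template is expected to render, assuming the slots are simply concatenated in order (the bos_token value "<|begin_of_text|>" comes from the Llama 3 tokenizer and is an assumption here, not part of this diff):

# Sketch only, not the project's rendering code.
bos_token = "<|begin_of_text|>"  # assumed Llama 3 BOS token
system = "You are a helpful assistant."
user = "What is 2 + 2?"
prompt = (
    bos_token
    + "<|start_header_id|>system<|end_header_id|>\n\n" + system + "<|eot_id|>"
    + "<|start_header_id|>user<|end_header_id|>\n\n" + user + "<|eot_id|>"
    + "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
# Decoding stops on "<|eot_id|>" because it is listed in stop_words, and
# replace_eos=True maps it onto the tokenizer's EOS token.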
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
@@ -657,10 +757,23 @@ _register_template(
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -687,11 +800,6 @@ _register_template(
)
_register_template(
name="vanilla",
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -762,7 +870,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
default_system="You are Zephyr, a helpful assistant.",
)

View File

@@ -44,7 +44,7 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - max_target_len
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
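A quick numeric check of the revised split (the lengths below are arbitrary illustration values): when the actual target is shorter than its proportional budget, the surplus now goes back to the source side instead of being discarded.

# Worked example with made-up lengths: max_len=1024, source_len=900, target_len=100.
max_len, source_len, target_len, reserved_label_len = 1024, 900, 100, 1
max_target_len = max(int(max_len * (target_len / (source_len + target_len))), reserved_label_len)  # 102
old_max_source_len = max_len - max_target_len                   # 922 (previous behavior)
new_max_source_len = max_len - min(max_target_len, target_len)  # 924: the 2 unused label slots return to the prompt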
@@ -78,9 +78,9 @@ def split_dataset(
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size

View File

@@ -14,16 +14,17 @@ from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .template import get_eval_template
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES

View File

@@ -1,14 +1,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
answer: str
prefix: str
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(support_set[k])
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(target_data)
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages
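As an illustration (the records below are invented, and the English template's answer suffix is assumed to be "\nAnswer:"), the refactored method now emits plain string roles rather than Role enum members:

# Hypothetical one-shot example; not taken from a real benchmark file.
support = [{"question": "1 + 1 =", "A": "1", "B": "2", "C": "3", "D": "4", "answer": "B"}]
target = {"question": "2 + 2 =", "A": "2", "B": "3", "C": "4", "D": "5", "answer": "C"}
messages = get_eval_template("en").format_example(target, support_set=support, subject_name="elementary mathematics")
# messages is expected to look roughly like:
# [{"role": "user", "content": "The following are multiple choice questions (with answers) about "
#                              "elementary mathematics.\n\n1 + 1 =\nA. 1\nB. 2\nC. 3\nD. 4\nAnswer:"},
#  {"role": "assistant", "content": "B"},
#  {"role": "user", "content": "2 + 2 =\nA. 2\nB. 3\nC. 4\nD. 5\nAnswer:"},
#  {"role": "assistant", "content": "C"}]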
@@ -39,7 +42,7 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
return eval_template
register_eval_template(
_register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
)
register_eval_template(
_register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n",
prefix=" ",
)

View File

@@ -58,6 +58,14 @@ class LogCallback(TrainerCallback):
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
@@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
r"""
Event called after logging the last logs.
"""
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict(
current_steps=self.cur_steps,
@@ -122,6 +134,7 @@ class LogCallback(TrainerCallback):
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,

View File

@@ -28,6 +28,10 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
MLLM_LIST = ["LLaVA1.5"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
@@ -39,9 +43,14 @@ TRAINING_STAGES = {
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"ORPO": "orpo",
"Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -167,6 +176,19 @@ register_model_group(
)
register_model_group(
models={
"Breeze-7B": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
},
"Breeze-7B-Chat": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
},
},
template="breeze",
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
@@ -226,6 +248,44 @@ register_model_group(
)
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group(
models={
"DBRX-132B-Base": {
DownloadSource.DEFAULT: "databricks/dbrx-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
},
"DBRX-132B-Chat": {
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
},
},
module="Wqkv",
template="dbrx",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
@@ -246,9 +306,11 @@ register_model_group(
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
@@ -347,6 +409,23 @@ register_model_group(
)
register_model_group(
models={
"CodeGemma-2B": {
DownloadSource.DEFAULT: "google/codegemma-2b",
},
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"InternLM-7B": {
@@ -394,6 +473,16 @@ register_model_group(
)
register_model_group(
models={
"Jambda-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
},
)
register_model_group(
models={
"LingoWhale-8B": {
@@ -460,14 +549,54 @@ register_model_group(
register_model_group(
models={
"Mistral-7B": {
"LLaMA3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"LLaMA3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"LLaMA3-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"LLaMA3-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
},
template="llama3",
)
register_model_group(
models={
"LLaVA1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
},
"LLaVA1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Chat": {
"Mistral-7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2": {
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
},
"Mistral-7B-v0.2-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
@@ -479,14 +608,21 @@ register_model_group(
register_model_group(
models={
"Mixtral-8x7B": {
"Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
"Mixtral-8x7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
},
},
template="mistral",
)
@@ -495,18 +631,15 @@ register_model_group(
register_model_group(
models={
"OLMo-1B": {
DownloadSource.DEFAULT: "allenai/OLMo-1B",
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
},
"OLMo-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-7B",
DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
},
"OLMo-7B-Chat": {
DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
"OLMo-1.7-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
},
},
module="att_proj",
template="olmo",
)
@@ -514,7 +647,7 @@ register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
}
},
template="openchat",
@@ -562,6 +695,22 @@ register_model_group(
)
register_model_group(
models={
"Phi3-3.8B-4k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi3-3.8B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-128k-instruct",
},
},
module="qkv_proj",
template="phi",
)
register_model_group(
models={
"Qwen-1.8B": {
@@ -656,10 +805,26 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-Code-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -680,10 +845,26 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-Code-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
@@ -724,6 +905,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
@@ -732,6 +917,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"Qwen1.5-Code-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
@@ -765,12 +962,15 @@ register_model_group(
models={
"StarCoder2-3B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
},
"StarCoder2-7B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
},
"StarCoder2-15B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
},
}
)
@@ -793,17 +993,53 @@ register_model_group(
register_model_group(
models={
"XuanYuan-6B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
},
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan-2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
"XuanYuan-6B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan-2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan-2-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan-2-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
},
template="xuanyuan",
@@ -840,6 +1076,30 @@ register_model_group(
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
@@ -932,21 +1192,9 @@ register_model_group(
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
"Zephyr-141B-ORPO-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
},
},
template="zephyr",
)
register_model_group(
models={
"Atom-7B": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
},
"Atom-7B-Chat": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
},
},
template="atom",
)

View File

@@ -64,7 +64,7 @@ def check_dependencies() -> None:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
@@ -83,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else:
num_bytes = 1
@@ -192,6 +194,13 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32
def has_tokenized_data(path: os.PathLike) -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
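A hedged usage sketch tying this helper to the `tokenized_path` option (renamed from `cache_path` later in this change set); the surrounding loader logic, with data_args in scope, is assumed rather than shown in this diff:

# Sketch only: skip tokenization when a saved dataset already exists.
from datasets import load_from_disk

if data_args.tokenized_path is not None and has_tokenized_data(data_args.tokenized_path):
    dataset = load_from_disk(data_args.tokenized_path)  # reuse the saved Arrow dataset
else:
    dataset = preprocess_and_maybe_save(data_args)      # hypothetical helper, not part of this diff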
def torch_gc() -> None:
r"""
Collects GPU memory.
@@ -202,17 +211,15 @@ def torch_gc() -> None:
torch.cuda.ipc_collect()
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
return
return model_args.model_name_or_path
try:
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
)
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")

View File

@@ -1,16 +1,23 @@
import importlib.metadata
import importlib.util
from typing import TYPE_CHECKING
from packaging import version
if TYPE_CHECKING:
from packaging.version import Version
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def _get_package_version(name: str) -> str:
def _get_package_version(name: str) -> "Version":
try:
return importlib.metadata.version(name)
return version.parse(importlib.metadata.version(name))
except Exception:
return "0.0.0"
return version.parse("0.0.0")
def is_fastapi_availble():
@@ -18,13 +25,17 @@ def is_fastapi_availble():
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available():
return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available():
return _is_package_available("jieba")
@@ -37,6 +48,10 @@ def is_nltk_available():
return _is_package_available("nltk")
def is_pillow_available():
return _is_package_available("PIL")
def is_requests_available():
return _is_package_available("requests")
@@ -45,14 +60,14 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available():
return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available():
return _is_package_available("uvicorn")

View File

@@ -1,38 +0,0 @@
import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
def patch_mixtral_replace_moe_impl() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -52,6 +52,6 @@ def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_")))
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
plt.savefig(figure_path, format="png", dpi=100)
print("Figure saved at:", figure_path)

View File

@@ -26,11 +26,11 @@ class DataArguments:
)
cutoff_len: int = field(
default=1024,
metadata={"help": "The cutoff length of the model inputs after tokenization."},
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
reserved_label_len: int = field(
default=1,
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
)
train_on_prompt: bool = field(
default=False,
@@ -84,9 +84,9 @@ class DataArguments:
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
},
)
cache_path: Optional[str] = field(
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the pre-processed datasets."},
metadata={"help": "Path to save or load the tokenized datasets."},
)
def __post_init__(self):

View File

@@ -102,10 +102,18 @@ class RLHFArguments:
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
dpo_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
@@ -114,10 +122,6 @@ class RLHFArguments:
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
)
ppo_score_norm: bool = field(
default=False,
metadata={"help": "Use score normalization in PPO training."},
@@ -168,7 +172,7 @@ class GaloreArguments:
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use gradient low-Rank projection."},
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
@@ -200,7 +204,54 @@ class GaloreArguments:
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_update_ratio: float = field(
default=0.0,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": """The mode of the mask for BAdam optimizer. \
`adjacent` means that the trainable parameters are adjacent to each other, \
`scatter` means that trainable parameters are randomly choosed from the weight."""
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": """The verbosity level of BAdam optimizer. \
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
},
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
@@ -209,7 +260,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
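Because these fields go through HfArgumentParser, the new BAdam options can be supplied as a plain argument dict (or the matching --flags); a minimal ratio-wise sketch, with model and dataset arguments left as placeholders:

# Sketch: enabling ratio-wise BAdam via the dict interface of get_train_args.
# Only the BAdam-related keys come from this diff; everything else is a placeholder.
args = {
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "use_badam": True,
    "badam_mode": "ratio",        # layer-wise mode is rejected under distributed training (see the parser check below)
    "badam_update_ratio": 0.05,
    "badam_mask_mode": "adjacent",
    # ... model_name_or_path, dataset, output_dir, etc. ...
}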
@@ -248,12 +299,18 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"

View File

@@ -31,11 +31,11 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
)
max_length: int = field(
default=512,
default=1024,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
)
max_new_tokens: int = field(
default=512,
default=1024,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
)
repetition_penalty: float = field(

View File

@@ -22,7 +22,7 @@ class ModelArguments:
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=False,
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
@@ -33,6 +33,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -53,22 +57,38 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: bool = field(
default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."},
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
@@ -121,6 +141,10 @@ class ModelArguments:
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: str = field(
default="cpu",
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
@@ -158,9 +182,15 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."

View File

@@ -11,8 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies
from ..extras.packages import is_unsloth_available
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -68,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
@@ -75,6 +77,35 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
@@ -119,20 +150,23 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.use_dora:
if model_args.quantization_bit is not None:
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
if model_args.use_unsloth:
if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.pure_bf16:
@@ -149,21 +183,33 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed.")
if (
finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -233,6 +279,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
@@ -264,17 +311,24 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.adapter_name_or_path is not None:
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support quantization.")
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
_verify_model_args(model_args, finetuning_args)
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None:
model_args.device_map = {"": torch.device(model_args.export_device)}
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args
@@ -292,6 +346,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"

View File

@@ -1,10 +1,11 @@
from .loader import load_model, load_model_and_tokenizer, load_tokenizer
from .utils import find_all_linear_modules, load_valuehead_params
from .loader import load_config, load_model, load_tokenizer
from .utils.misc import find_all_linear_modules
from .utils.valuehead import load_valuehead_params
__all__ = [
"load_config",
"load_model",
"load_model_and_tokenizer",
"load_tokenizer",
"load_valuehead_params",
"find_all_linear_modules",

View File

@@ -5,11 +5,13 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras.logging import get_logger
from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
from .utils.misc import find_all_linear_modules, find_expanded_modules
from .utils.quantization import QuantizationMethod
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
@@ -18,7 +20,11 @@ logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -32,9 +38,12 @@ def init_adapter(
logger.info("Adapter is not found at evaluation, load the base model.")
return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -66,6 +75,8 @@ def init_adapter(
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
@@ -79,7 +90,7 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -100,6 +111,10 @@ def init_adapter(
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
@@ -116,8 +131,14 @@ def init_adapter(
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(
model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
if is_trainable and adapter_to_resume is None: # create new lora weights while training
@@ -129,34 +150,45 @@ def init_adapter(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

View File

@@ -1,17 +1,20 @@
from typing import TYPE_CHECKING, Any, Dict, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
from .utils.misc import register_autoclass
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .utils.unsloth import load_unsloth_pretrained_model
from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
@@ -19,7 +22,18 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
@@ -28,15 +42,14 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
}
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer. Must before load_model.
Loads pretrained tokenizer.
Note: including inplace operation of model_args.
"""
try_download_model_from_ms(model_args)
init_kwargs = _get_init_kwargs(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
@@ -44,8 +57,41 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
return tokenizer
if model_args.visual_inputs:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
else:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
@@ -56,45 +102,42 @@ def load_model(
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model. Must after load_tokenizer.
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
model = None
if is_trainable and model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args)
unsloth_kwargs = {
"model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
}
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.adapter_name_or_path:
model_args.adapter_name_or_path = None
logger.warning("Unsloth does not support loading adapters.")
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model is None:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
patch_model(model, tokenizer, model_args, is_trainable)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
@@ -110,9 +153,6 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
model.eval()
for param in model.parameters():
if param.device.type == "cuda":
param.data = param.data.to(model_args.compute_dtype)
else:
model.train()
@@ -134,17 +174,3 @@ def load_model(
)
return model
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
"""
tokenizer = load_tokenizer(model_args)
model = load_model(tokenizer, model_args, finetuning_args, is_trainable, add_valuehead)
return model, tokenizer

View File

@@ -1,24 +1,22 @@
import math
import os
import random
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Dict
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
from .utils import QuantizationMethod
from ..extras.misc import infer_optim_dtype
from .utils.attention import configure_attn_implementation, print_attn_implementation
from .utils.checkpointing import prepare_model_for_training
from .utils.embedding import resize_embedding_layer
from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.valuehead import configure_valuehead, prepare_valuehead_model
from .utils.visual import autocast_projector_dtype
if TYPE_CHECKING:
@@ -29,244 +27,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def _configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
) -> None:
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
return
logger.info("Using FlashAttention-2 for faster training and inference.")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", "flash_attention_2")
else:
init_kwargs["attn_implementation"] = "flash_attention_2"
else:
init_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -280,60 +40,76 @@ def patch_config(
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
add_valuehead: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
configure_attn_implementation(config, model_args)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
if add_valuehead:
configure_valuehead(config)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
_configure_attn_implementation(config, model_args, init_kwargs)
_configure_rope(config, model_args, is_trainable)
_configure_longlora(config, model_args, is_trainable)
_configure_quantization(config, tokenizer, model_args, init_kwargs)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs: # quant models cannot use auto device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map
if init_kwargs["device_map"] == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool,
add_valuehead: bool,
) -> None:
gen_config = model.generation_config # check and fix generation config
if not gen_config.do_sample and (
(gen_config.temperature is not None and gen_config.temperature != 1.0)
or (gen_config.top_p is not None and gen_config.top_p != 1.0)
or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
):
gen_config.do_sample = True
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if add_valuehead:
prepare_valuehead_model(model)
if model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer)
resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable:
_prepare_model_for_training(model, model_args)
prepare_model_for_training(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if is_trainable:
patch_mixtral_replace_moe_impl()
if not model_args.use_unsloth:
print_attn_implementation(model.config)
try:
model.add_model_tags(["llama-factory"])

View File

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "off":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
logger.warning("Torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla Attention implementation.")

View File

@@ -0,0 +1,94 @@
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(_fp32_forward_post_hook)
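prepare_model_for_training relies on a forward post-hook to upcast lm_head outputs to float32 when the weights are kept in reduced precision. A self-contained sketch of the same hook mechanism in plain PyTorch (the toy Linear layer stands in for an lm_head; bfloat16 is used so the snippet also runs on CPU):

import torch

lm_head = torch.nn.Linear(8, 4).to(torch.bfloat16)  # stand-in for a reduced-precision lm_head

def fp32_forward_post_hook(module, args, output):
    # Same idea as _fp32_forward_post_hook above: upcast the layer output to float32.
    return output.to(torch.float32)

lm_head.register_forward_hook(fp32_forward_post_hook)
out = lm_head(torch.randn(2, 8, dtype=torch.bfloat16))
print(out.dtype)  # torch.float32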

View File

@@ -0,0 +1,58 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))

View File

@@ -1,5 +1,5 @@
import math
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
@@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
def llama_torch_attn_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -39,10 +48,11 @@ def llama_torch_attn_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -69,8 +79,9 @@ def llama_torch_attn_forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -97,8 +108,8 @@ def llama_torch_attn_forward(
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -117,7 +128,6 @@ def llama_flash_attn_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -134,9 +144,10 @@ def llama_flash_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
@@ -192,7 +203,115 @@ def llama_flash_attn_forward(
return attn_output, attn_weights, past_key_value
def apply_llama_patch() -> None:
require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
return llama_attention_forward(
self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, :groupsz]
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)  # assign the reshape back
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,  # concatenate along the head dimension, mirroring shift() above
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _apply_llama_patch() -> None:
require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")

View File

@@ -1,38 +1,21 @@
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, List
import torch
from transformers import PreTrainedModel
from transformers.utils import cached_file
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ...extras.logging import get_logger
from .quantization import QuantizationMethod
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ..hparams import ModelArguments
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r"""
Finds all available modules to apply lora.
Finds all available modules to apply lora or galore.
"""
quantization_method = getattr(model, "quantization_method", None)
if quantization_method is None:
@@ -47,6 +30,8 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
elif model.config.model_type == "internlm2":
output_layer_names.append("output")
module_names = set()
for name, module in model.named_modules():
@@ -84,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
return module_names
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
from MoD import AutoMoDModelForCausalLM
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
def convert_pretrained_model_to_mod(
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
return model

View File

@@ -0,0 +1,53 @@
from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if not is_deepspeed_zero3_enabled():
return
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
set_z3_leaf_modules(model, [DbrxFFN])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@@ -0,0 +1,146 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ...hparams import ModelArguments
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

View File

@@ -0,0 +1,47 @@
import math
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
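A worked example of the scaling factor arithmetic in configure_rope, using illustrative lengths: a model pretrained with max_position_embeddings=4096 that should be fine-tuned at 16384 tokens gets a linear scaling factor of ceil(16384 / 4096) = 4.

import math

max_position_embeddings = 4096   # from the pretrained config (illustrative)
model_max_length = 16384         # requested training length (illustrative)

scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))
rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # {'type': 'linear', 'factor': 4.0}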

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"use_gradient_checkpointing": "unsloth",
}
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False
return model
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
if not is_trainable:
FastLanguageModel.for_inference(model)
return model

View File

@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_valuehead(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava":
setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
if getattr(model.config, "model_type", None) == "llava":
setattr(model, "lm_head", model.language_model.get_output_embeddings())
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Tuple
import torch
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name):
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
@@ -8,7 +9,7 @@ from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -20,12 +21,9 @@ if TYPE_CHECKING:
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
finetuning_args: "FinetuningArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: bool = True,
**kwargs,
):
@@ -47,10 +45,10 @@ class CustomDPOTrainer(DPOTrainer):
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self.beta = beta
self.label_smoothing = 0
self.loss_type = loss_type
self.ftx_gamma = ftx_gamma
self.beta = finetuning_args.dpo_beta
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
@@ -66,14 +64,23 @@ class CustomDPOTrainer(DPOTrainer):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.create_optimizer()
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
@@ -84,18 +91,27 @@ class CustomDPOTrainer(DPOTrainer):
return -all_logps
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under the given logits if loss_type != IPO.
Otherwise the average log probabilities.
"""
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model(
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
all_logits: "torch.Tensor" = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
all_logps = self.get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False,
logits=all_logits,
labels=batch_copied["labels"],
average_log_prob=(self.loss_type == "ipo"),
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
batch_size = batch["input_ids"].size(0) // 2
@@ -106,9 +122,9 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor],
batch: Dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
@@ -149,13 +165,13 @@ class CustomDPOTrainer(DPOTrainer):
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
return losses.mean(), metrics
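The trainer above delegates the actual preference loss to trl's DPOTrainer.dpo_loss; for orientation, a reference sketch of the standard sigmoid DPO objective computed from the four log-probability tensors that concatenated_forward produces (the function name and toy values are ours, not part of the repository):

import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # loss = -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards

losses, chosen_rewards, rejected_rewards = dpo_sigmoid_loss(
    torch.tensor([-10.0]), torch.tensor([-14.0]), torch.tensor([-11.0]), torch.tensor([-13.0])
)
print(losses.mean().item())  # about 0.60 for these toy log-probabilities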

View File

@@ -2,13 +2,12 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import get_dataset, split_dataset
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push, create_ref_model
from .collator import DPODataCollatorWithPadding
from .trainer import CustomDPOTrainer
@@ -25,10 +24,12 @@ def run_dpo(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
- tokenizer = load_tokenizer(model_args)
- dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
- data_collator = DPODataCollatorWithPadding(
+ data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
@@ -45,13 +46,10 @@ def run_dpo(
# Initialize our Trainer
trainer = CustomDPOTrainer(
- beta=finetuning_args.dpo_beta,
- loss_type=finetuning_args.dpo_loss,
- ftx_gamma=finetuning_args.dpo_ftx,
+ finetuning_args=finetuning_args,
model=model,
ref_model=ref_model,
args=training_args,
- finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
@@ -66,7 +64,7 @@ def run_dpo(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
# Evaluation
if training_args.do_eval:

View File
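The workflow above now builds PairwiseDataCollatorWithPadding instead of the removed DPO-specific collator; the trainers then split each batch in half (batch["input_ids"].size(0) // 2) on the assumption that chosen examples come first and rejected examples second. A minimal, hypothetical illustration of that layout, not the actual collator implementation:

import torch

def concat_pairs(chosen, rejected, pad_id=0):
    # Illustrative pairwise collation: chosen sequences first, rejected second, right-padded.
    seqs = list(chosen) + list(rejected)
    max_len = max(s.size(0) for s in seqs)
    input_ids = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, : s.size(0)] = s
    return {"input_ids": input_ids, "attention_mask": (input_ids != pad_id).long()}

batch = concat_pairs([torch.tensor([11, 12, 13])], [torch.tensor([11, 14])])
half = batch["input_ids"].size(0) // 2
chosen_ids, rejected_ids = batch["input_ids"][:half], batch["input_ids"][half:]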

@@ -0,0 +1,4 @@
from .workflow import run_orpo
__all__ = ["run_orpo"]

View File

@@ -0,0 +1,127 @@
from collections import defaultdict
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import FinetuningArguments
class CustomORPOTrainer(DPOTrainer):
def __init__(
self,
model: Union["PreTrainedModel", "torch.nn.Module"],
finetuning_args: "FinetuningArguments",
disable_dropout: bool = True,
**kwargs,
):
if disable_dropout:
disable_dropout_in_model(model)
self.finetuning_args = finetuning_args
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.is_encoder_decoder = model.config.is_encoder_decoder
self.precompute_ref_log_probs = False
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self._peft_has_been_casted_to_bf16 = False
self.beta = finetuning_args.orpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss.
"""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
odds_ratio_loss = -F.logsigmoid(log_odds)
return odds_ratio_loss
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the average log probabilities of the labels under the given logits.
"""
all_logits: "torch.Tensor" = model(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
).logits.to(torch.float32)
all_logps = self.get_batch_logps(
logits=all_logits,
labels=batch["labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the ORPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
sft_loss = -chosen_logps
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
chosen_rewards = self.beta * chosen_logps.detach()
rejected_rewards = self.beta * rejected_logps.detach()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
return batch_loss, metrics

View File
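To unpack the new ORPO trainer: concatenated_forward returns length-averaged log-probabilities, so chosen_logps and rejected_logps are strictly negative and exp of them lies in (0, 1), which keeps log1p(-exp(...)) well defined; the objective adds an odds-ratio penalty, with odds(p) = p / (1 - p), to a plain SFT term. A self-contained recomputation of the same terms outside the Trainer machinery, for illustration only:

import torch
import torch.nn.functional as F

def orpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor, beta: float) -> torch.Tensor:
    # Inputs are average per-token log-probabilities of the chosen / rejected responses.
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    sft_loss = -chosen_logps            # negative log-likelihood of the chosen response
    or_loss = -F.logsigmoid(log_odds)   # odds-ratio penalty favouring chosen over rejected
    return (sft_loss + beta * or_loss).mean()

print(orpo_loss(torch.tensor([-0.4]), torch.tensor([-1.2]), beta=0.1))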

@@ -0,0 +1,69 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from .trainer import CustomORPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments
def run_orpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
# Initialize our Trainer
trainer = CustomORPOTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args),
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies", "sft_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -1,25 +1,29 @@
import math
import os
import sys
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
- from trl import PPOTrainer
+ from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..utils import create_custom_optimzer, create_custom_scheduler
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING:
- from transformers import Seq2SeqTrainingArguments, TrainerCallback
+ from datasets import Dataset
+ from transformers import DataCollatorWithPadding, PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
@@ -40,10 +44,53 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
- reward_model: "AutoModelForCausalLMWithValueHead",
- **kwargs,
+ model: "AutoModelForCausalLMWithValueHead",
+ reward_model: Optional["AutoModelForCausalLMWithValueHead"],
+ ref_model: Optional["AutoModelForCausalLMWithValueHead"],
+ tokenizer: "PreTrainedTokenizer",
+ dataset: "Dataset",
+ data_collator: "DataCollatorWithPadding",
):
- PPOTrainer.__init__(self, **kwargs)
+ backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
+ ppo_config = PPOConfig(
+ model_name=model_args.model_name_or_path,
+ learning_rate=training_args.learning_rate,
+ mini_batch_size=training_args.per_device_train_batch_size,
+ batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+ ppo_epochs=finetuning_args.ppo_epochs,
+ max_grad_norm=training_args.max_grad_norm,
+ seed=training_args.seed,
+ optimize_device_cache=True,
+ target=finetuning_args.ppo_target,
+ use_score_scaling=finetuning_args.ppo_score_norm,
+ use_score_norm=finetuning_args.ppo_score_norm,
+ whiten_rewards=finetuning_args.ppo_whiten_rewards,
+ accelerator_kwargs={"step_scheduler_with_optimizer": False},
+ log_with=training_args.report_to[0] if training_args.report_to else None,
+ project_kwargs={"logging_dir": training_args.logging_dir},
+ )
+ # Create optimizer and scheduler
+ if training_args.max_steps > 0:
+ num_training_steps = training_args.max_steps
+ else:
+ total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
+ num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+ optimizer = self.create_optimizer(model, training_args, finetuning_args)
+ scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
+ PPOTrainer.__init__(
+ self,
+ config=ppo_config,
+ model=model,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ dataset=dataset,
+ data_collator=data_collator,
+ lr_scheduler=scheduler,
+ )
self.args = training_args
self.model_args = model_args
@@ -78,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -205,6 +257,44 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
def create_optimizer(
self,
model: "AutoModelForCausalLMWithValueHead",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
if optimizer is None:
decay_params, nodecay_params = [], []
decay_param_names = self.get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=nodecay_params),
dict(params=decay_params, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
return optimizer
def create_scheduler(
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(training_args, num_training_steps, optimizer)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
return lr_scheduler
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
@@ -269,7 +359,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
- _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+ _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1)

View File
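With PPOConfig construction and step counting moved into CustomPPOTrainer above, the batch-size arithmetic reads: backward_batch_size is per_device_train_batch_size times gradient_accumulation_steps, the PPO rollout batch additionally multiplies in ppo_buffer_size, and the optimizer step count divides the dataset by that product times world_size. A small worked example with made-up numbers:

import math

# Hypothetical settings, chosen only to make the arithmetic concrete.
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
ppo_buffer_size = 2
world_size = 2
num_train_epochs = 3
dataset_len = 10_000

backward_batch_size = per_device_train_batch_size * gradient_accumulation_steps          # 32
ppo_batch_size = backward_batch_size * ppo_buffer_size                                   # 64 rollouts per PPO step
total_train_batch_size = backward_batch_size * ppo_buffer_size * world_size              # 128 across all ranks
num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)  # 3 * 79 = 237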

@@ -1,19 +1,15 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
- import math
from typing import TYPE_CHECKING, List, Optional
- from torch.optim import AdamW
from transformers import DataCollatorWithPadding
- from transformers.optimization import get_scheduler
- from trl import PPOConfig
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
- from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+ from ..utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -31,8 +27,9 @@ def run_ppo(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
- tokenizer = load_tokenizer(model_args)
- dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
@@ -42,45 +39,6 @@ def run_ppo(
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
reward_model = create_reward_model(model, model_args, finetuning_args)
- # Create ppo config
- backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
- ppo_config = PPOConfig(
- model_name=model_args.model_name_or_path,
- learning_rate=training_args.learning_rate,
- mini_batch_size=training_args.per_device_train_batch_size,
- batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
- gradient_accumulation_steps=training_args.gradient_accumulation_steps,
- ppo_epochs=finetuning_args.ppo_epochs,
- max_grad_norm=training_args.max_grad_norm,
- seed=training_args.seed,
- optimize_device_cache=True,
- target=finetuning_args.ppo_target,
- log_with=finetuning_args.ppo_logger,
- use_score_scaling=finetuning_args.ppo_score_norm,
- use_score_norm=finetuning_args.ppo_score_norm,
- whiten_rewards=finetuning_args.ppo_whiten_rewards,
- accelerator_kwargs={"step_scheduler_with_optimizer": False},
- project_kwargs={"logging_dir": training_args.logging_dir},
- )
- # Create optimizer and scheduler
- if training_args.max_steps > 0:
- num_training_steps = training_args.max_steps
- else:
- total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
- num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
- optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
- if optimizer is None:
- optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
- lr_scheduler = get_scheduler(
- training_args.lr_scheduler_type,
- optimizer=optimizer,
- num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
- num_training_steps=num_training_steps,
- )
# Initialize our Trainer
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
@@ -88,15 +46,12 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [FixValueHeadModelCallback()],
- reward_model=reward_model,
- config=ppo_config,
model=model,
+ reward_model=reward_model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
)
# Training
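On the unchanged line tokenizer.padding_side = "left" above: generation samples from the logits at the last position of every row, so padding has to sit on the left during PPO rollouts, while training keeps the usual right padding. A toy illustration with plain lists, no tokenizer involved:

# Right-padded rows: the last column mixes real tokens and padding,
# so reading logits[:, -1, :] would not line up with each prompt's latest token.
right_padded = [[5, 6, 7, 0],
                [8, 9, 0, 0]]

# Left-padded rows: the last column always holds the most recent real token,
# which is exactly what next-token sampling needs.
left_padded = [[0, 5, 6, 7],
               [0, 0, 8, 9]]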

Some files were not shown because too many files have changed in this diff.