146 Commits

Author SHA1 Message Date
hiyouga
3cef844079 fix setup
Former-commit-id: 7d3e7db46a5f8672dd57fa5fcc03822e175047f9
2024-04-28 03:49:13 +08:00
hiyouga
4dcd47100d fix llava rlhf
Former-commit-id: f6863cbbcbf960d6481296c6cae3e40fd70e4e14
2024-04-28 03:01:49 +08:00
hiyouga
a412b4ed4a add models to 0.7.0
Former-commit-id: 436d3754452f839c617839ab3bbaacc4a8908e19
2024-04-28 01:50:30 +08:00
hiyouga
544a6259b6 update readme
Former-commit-id: c9190fe36f511c3a5149d45c85a10b02a57fa88a
2024-04-26 23:39:19 +08:00
hiyouga
c501f377dd release v0.7.0
Former-commit-id: 45bb89cb4d26a6b3fb5360bc90ab950738fe4920
2024-04-26 23:18:00 +08:00
hiyouga
cb8b8f40cd update readme
Former-commit-id: f3d4b46338d4d484b205d0651a1fa7b2e77a1654
2024-04-26 20:09:14 +08:00
hiyouga
70bed8ad8f support Qwen1.5 110B
Former-commit-id: d6e5ecaf4109127bab24e39a0696076bceb0b37c
2024-04-26 19:59:22 +08:00
hiyouga
51f776ae2a fix llava qlora
Former-commit-id: 01c5a669f6fe598aac1758a700a7607da37db1bc
2024-04-26 18:00:23 +08:00
hiyouga
697bc20941 add llava to llamaboard
Former-commit-id: deaaff0a9de0eef9691991c99cd797461b1165cc
2024-04-26 06:41:35 +08:00
hiyouga
1480e3a88f update readme
Former-commit-id: df1155245d3f71ba4f3361d43aa662ab3b024de8
2024-04-26 05:49:26 +08:00
hoshi-hiyouga
19029d5b0f Merge pull request #3454 from hiyouga/mllm
Support fine-tuning LLaVA-1.5 MLLM @BUAADreamer 

Former-commit-id: c4195d1e26349795f7aad5c10a8a9e2abb7b64a3
2024-04-26 05:46:29 +08:00
hiyouga
7773ac0ead update readme
Former-commit-id: 41728fd74de7bec0cc6135aef9dfa3ae9fe7af73
2024-04-26 05:44:30 +08:00
hiyouga
23b881bff1 support mllm hf inference
Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
2024-04-26 05:34:58 +08:00
hoshi-hiyouga
10a6c395bb Merge pull request #3450 from BUAADreamer/mllm
Add Multimodal LLM Finetuning

Former-commit-id: 7cacbcfdf7391080ef43eb2b2c79a5237e6120e8
2024-04-26 05:30:30 +08:00
hoshi-hiyouga
f9a7732a1f Update preprocess.py
Former-commit-id: 0e376eab23d38b8fca05f054f3cde308756ee3b1
2024-04-26 04:10:28 +08:00
hoshi-hiyouga
c37582af02 Update aligner.py
Former-commit-id: 855489074c469f47572153df0fa1e251b187b232
2024-04-26 03:48:34 +08:00
hoshi-hiyouga
ece67f8c7f Update parser.py
Former-commit-id: 4df75e8a9a391565cc3eec69bc0ebf5d5192de61
2024-04-26 03:35:39 +08:00
hoshi-hiyouga
e1838e76fe Update loader.py
Former-commit-id: 6a5f2e2ab7304113ff71cb77aafff6a1f74831f8
2024-04-26 03:33:07 +08:00
hoshi-hiyouga
2eede9ffd6 Update workflow.py
Former-commit-id: 5b8b5b975716d539ae2fae8536f79e106aa0b566
2024-04-26 03:29:12 +08:00
hoshi-hiyouga
a6f6b406b3 Update loader.py
Former-commit-id: 72d4817a15f6916706828ea2a61d808183c23773
2024-04-26 03:22:40 +08:00
hoshi-hiyouga
279439abbe update hparam name
Former-commit-id: 9941adfbf06db37f8ba32c4555f6e58e27188aaf
2024-04-26 02:49:39 +08:00
hoshi-hiyouga
13117b69d7 delete llava template (use vicuna)
Former-commit-id: 420e64970e5a0e45453041927e0366ee8beb73d5
2024-04-26 02:20:47 +08:00
BUAADreamer
5d03ac642d modify some bug
Former-commit-id: 593b7b004df74bd24361c9883401a656c08fb589
2024-04-25 22:59:46 +08:00
BUAADreamer
5062ee547e modify some style
Former-commit-id: 1291c7ee39361dd75247c67f04dcf20b472faf83
2024-04-25 22:40:53 +08:00
BUAADreamer
59817c27e3 modify some style
Former-commit-id: d578a90cefa7ec813355795bdd6ead5ee558ce26
2024-04-25 22:40:25 +08:00
BUAADreamer
759bee48d2 merge some func
Former-commit-id: 3085107c44715e4b2ca96d73b20d90c172b95219
2024-04-25 22:35:17 +08:00
BUAADreamer
514ffafc12 modify some style
Former-commit-id: 053062abc007014a7fde95c5ae9f4d859893d8ad
2024-04-25 22:04:09 +08:00
BUAADreamer
8b2a735c14 modify some style
Former-commit-id: b016e6a671a2f228f0bdd9b8d5995b4669609655
2024-04-25 21:58:18 +08:00
BUAADreamer
10d59e9e4a make dataset script
Former-commit-id: 25892f958da14976025a775febf628cd0e0a3d85
2024-04-25 21:32:01 +08:00
BUAADreamer
058ed5e607 modify style
Former-commit-id: c1f1df99e4dc3d0aadf1207b4e9a16218187fd5a
2024-04-25 21:29:50 +08:00
BUAADreamer
110c2ce2a5 modify style
Former-commit-id: 3bffc1e1b8bcc4582cebea06d35e5146163c7bec
2024-04-25 21:27:48 +08:00
BUAADreamer
c425436676 modify style
Former-commit-id: 54b713d0c4ffdfc6a7faeb14471b58bb1cd8acf5
2024-04-25 21:15:16 +08:00
BUAADreamer
266fe908e3 Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: c4bb5af69c5bbf0b1ea044cbb2b18acddc6733ac
2024-04-25 21:08:40 +08:00
BUAADreamer
dbd905438b add some
Former-commit-id: 8d035a849c4a441d457791aab073861adf69a09f
2024-04-25 21:08:32 +08:00
hoshi-hiyouga
d64c87f928 Merge pull request #3449 from hiyouga/mllm
add webui backend option

Former-commit-id: 372fcedef40b79fe8bd3932c06c720f2a03db6e6
2024-04-25 20:58:16 +08:00
hiyouga
29eebef696 add webui backend option
Former-commit-id: 3764586cb3ed64fe376d0ae420ff5690c28459e2
2024-04-25 20:49:23 +08:00
hiyouga
7bfbcb1fe3 vllm + lora support
Former-commit-id: 8cb86ba355195f5d6dcb95ee6b6b7203463a34db
2024-04-25 20:24:31 +08:00
BUAADreamer
9b210cf4b3 rm some
Former-commit-id: 2c85b4fabbebd8b51eee53f5d29184d4a6e97569
2024-04-25 20:09:43 +08:00
BUAADreamer
f74e640565 Merge branch 'hiyouga:main' into main
Former-commit-id: 131d0bcd554dedd794add7eb3d7b1201cac80e7c
2024-04-25 20:02:50 +08:00
BUAADreamer
d1d08d066a merge data part to the text stream
Former-commit-id: 80537d580119d9d5a06ab236a5284aaae2f83b5b
2024-04-25 19:58:47 +08:00
hiyouga
6be321b5da fix #3374
Former-commit-id: 0097d7968b3b570e1705caff26f42d9ed71ad974
2024-04-25 19:56:49 +08:00
BUAADreamer
3c792174db merge data part to the text stream
Former-commit-id: 7ee20286d9bcc2d5378bfd6bb02cd3648396d873
2024-04-25 19:19:59 +08:00
hiyouga
9aeb88c426 add export_device in webui #3333
Former-commit-id: 30ebd3652809d73941e0a5e4a8be11d989faf98d
2024-04-25 19:02:32 +08:00
BUAADreamer
00e2a272ef merge model part to the text stream
Former-commit-id: b6fcb832ddaed4647d6f2b926f3dfccd47f3ea84
2024-04-25 08:20:41 +08:00
BUAADreamer
5142349661 remove error
Former-commit-id: 2bcd1c7dc3595f17ae4e2c4475196cc2d03d0e75
2024-04-25 01:01:59 +08:00
BUAADreamer
0e3cc52327 remove conflicts
Former-commit-id: e5750ee202eb67cf5fc54f464548e2eb43d00900
2024-04-25 00:56:06 +08:00
BUAADreamer
6c1db2d012 remove conflicts
Former-commit-id: f8b637eb76cba7ec229e2978068805ad1cca8adb
2024-04-25 00:34:22 +08:00
BUAADreamer
12c51655ce add llava and instructblip
Former-commit-id: 142fb6f4541a1acfefe66ff2574dabde53b00c06
2024-04-25 00:22:43 +08:00
hiyouga
36be12a3b7 update tool template
Former-commit-id: c72a1981859818c257c5271d32e03c9d3c344206
2024-04-25 00:21:34 +08:00
hiyouga
21fac4c98c fix log level
Former-commit-id: 8d21302f6201b3f33c10f61f3559bd95be3363c2
2024-04-24 23:42:59 +08:00
hiyouga
83404c4fa9 support new special token #3420
Former-commit-id: f5c6a47f5193ab3a6c137580992bdcce0b31fdd5
2024-04-24 23:39:31 +08:00
hoshi-hiyouga
12f852b8d4 fix phi template
Former-commit-id: 14a1ff665eaebfc618229efbe96f09848d52faec
2024-04-24 13:55:14 +08:00
hoshi-hiyouga
a88873116a fix webchatmodel
Former-commit-id: dc6d8b5dc42c363dd180aaf90c9a2f2d0cce6725
2024-04-24 13:54:21 +08:00
hoshi-hiyouga
7cfcd69c64 fix inference in llamaboard
Former-commit-id: 5e631915157083b61e2d5a183e0c91f2d11f416e
2024-04-24 13:53:39 +08:00
hiyouga
a5eabbe933 add olmo 1.7
Former-commit-id: 86a3fb3a141d2702b15af08df36ffcf9b3d6de14
2024-04-24 05:50:50 +08:00
hiyouga
aa25716a5d add dbrx and jamba models
Former-commit-id: ce35c80b4b00152185285d6064939803d14487f0
2024-04-24 05:39:52 +08:00
hiyouga
94c8219575 fix bug
Former-commit-id: 38e164fe4aaea6f0baf121a720291ca42643ba8c
2024-04-24 05:21:18 +08:00
hiyouga
ad24a2a0c9 fix bug
Former-commit-id: 271c24d2c82d645fa9072e6de94ca38f20411537
2024-04-24 05:10:07 +08:00
hiyouga
c05027d14a remove redundant code
Former-commit-id: 4a7a7ad2bcdc493458084f5f3d384239228b7d5a
2024-04-24 05:02:18 +08:00
hiyouga
5420905a2e support unsloth generate
Former-commit-id: 0ef1ad9f505dba71db9342f524cc3a7565e5e09e
2024-04-24 04:46:53 +08:00
hiyouga
03f2e3284a refactor patcher
Former-commit-id: 263cfe1294f5c3188f5e8d65791f35ee0d87315a
2024-04-24 03:02:23 +08:00
hiyouga
d2bb1b3a6b reenable sdpa and fast tok by default
Former-commit-id: 9e00902dbedc71d55743d1bf237843506a557891
2024-04-24 02:18:44 +08:00
hiyouga
35c4a2c212 fix #3347 #3387
Former-commit-id: c253c18185a29b59190f3e0ed236c2bb4c788085
2024-04-24 01:30:16 +08:00
hiyouga
1e4010a1fb support phi-3
Former-commit-id: 7e8ffa9beee3893e051ceeade443bd56c4a07b1c
2024-04-24 00:28:53 +08:00
BUAADreamer
1451297c78 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: 67800c565b086f362b8cf131b0c9babaa7a7ebc7
2024-04-23 19:22:42 +08:00
BUAADreamer
0b99b13786 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: b78b5f290aa38a7454e101ee9703fb6fac5064ac
2024-04-23 18:47:03 +08:00
BUAADreamer
f5edbf2b49 Merge branch 'hiyouga:main' into main
Former-commit-id: 6287d1b789c631205c1033adf036e28deaef4167
2024-04-23 18:46:12 +08:00
BUAADreamer
ab6dc0ea30 add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: a730f89a972f1a9d37c718c716f199cb8d4903b2
2024-04-23 18:45:43 +08:00
hiyouga
79d34ce0f3 update examples
Former-commit-id: 8bf55682cdfbbdca0f01073eac0084c20a6a09d1
2024-04-23 18:29:46 +08:00
hiyouga
1d2e372a8e update readme
Former-commit-id: d4eaee262a64e716ce475dc4eb18d8d9697d8dd8
2024-04-22 17:09:17 +08:00
hiyouga
f6a53d83c8 update readme
Former-commit-id: 3eab580703ee01a0d2d75e7f01df5165af551386
2024-04-22 00:51:35 +08:00
hiyouga
4ec56dd958 update readme
Former-commit-id: fdca136309709e43d75a831252b9375a5a99635a
2024-04-22 00:42:25 +08:00
hiyouga
ba06eb65ca update readme and examples
Former-commit-id: 27dd9bf201c24f7804811398bc2758966ec78432
2024-04-22 00:37:32 +08:00
hiyouga
be716972fe remove extras
Former-commit-id: d67e972f8c3d5273e589c8c85c0a1620f59785c5
2024-04-22 00:35:41 +08:00
hiyouga
719585a128 update readme
Former-commit-id: 3a8c17907c71f46b1b37501e2afdc99ad89fb4bc
2024-04-22 00:21:01 +08:00
hiyouga
348f29aa50 set dev version
Former-commit-id: b9557887d7506ff57b2b2bf490092aac4e4becf0
2024-04-21 23:14:30 +08:00
hiyouga
c8fe3f544b release v0.6.3
Former-commit-id: 947572af8de201669598f54735f35b50bb719d71
2024-04-21 23:13:23 +08:00
hiyouga
0f1ad7140f fix #3366
Former-commit-id: dc20237455c36de44f8922539d7dfadd8bedb12f
2024-04-21 21:34:25 +08:00
hiyouga
233e167f68 fix optimizers
Former-commit-id: f811eee2fa12a89a55a9c5d3a05a1521b4347727
2024-04-21 20:40:54 +08:00
hiyouga
1d341dcd83 fix #3365
Former-commit-id: 415ce41e8fa887e980e5bd575c8e95bd4076b90b
2024-04-21 19:20:18 +08:00
hiyouga
d16561e7a4 fix bug in galore optimizer
Former-commit-id: c05ac23261a5a8ba893c2918a43dc7777307407b
2024-04-21 18:53:22 +08:00
hiyouga
f8e219dc81 fix mod stuff
Former-commit-id: cf3988226e6398c67bb2955578e436fc505aa5c5
2024-04-21 18:11:10 +08:00
hoshi-hiyouga
3365cc8cf0 Merge pull request #3338 from astramind-ai/main
Adding Mixture of Depth

Former-commit-id: 4da2ece53353b63e672ff529d6beba41ff710c14
2024-04-21 18:05:52 +08:00
hoshi-hiyouga
3a5e68b7d9 fix #3348
Former-commit-id: aa5e921c00f60074eceb2f9d4d8837cc713edba6
2024-04-20 10:34:09 +08:00
hiyouga
0cb596fee1 add dpo mix dataset
Former-commit-id: 6def3f8bfa51b2d9d73af112352ce07db972e4c9
2024-04-20 01:31:38 +08:00
hiyouga
b3b5b530d1 fix #3352
Former-commit-id: f315f8e8ec916b82bac94a159e55839ff155c6b5
2024-04-19 22:40:01 +08:00
hiyouga
9225c15c88 fix llama3 template
Former-commit-id: 20e95250168fbe081c779b2e1ff23f5df3ce02f7
2024-04-19 15:46:51 +08:00
Marco
abd9fed445 fix small typo
Former-commit-id: 5638a03cd0cf8119ff366b3b3e303b5a2351b065
2024-04-18 20:33:29 +02:00
Marco
44cda2eece Added Mixture of Depths
Former-commit-id: 75dd98b9abc847e22cb263c17ebcd2ca5dd98345
2024-04-18 20:31:24 +02:00
hoshi-hiyouga
8397808d1d support llama3
Former-commit-id: c1eabb751a5fd73b710714451b146732e0ed4558
2024-04-19 01:13:50 +08:00
hiyouga
9e1bd6420d fix #3324
Former-commit-id: 5e710c4ac331f3400534d33b2646c4108c898d98
2024-04-18 15:34:45 +08:00
hiyouga
619264c854 tiny fix
Former-commit-id: 86399ca8c06273c42c2b184664ae25d3405b3bf6
2024-04-18 00:22:17 +08:00
hiyouga
1ebac62e3d update readme
Former-commit-id: a49112a74339ba77bfec53f7870e821fe148db2c
2024-04-17 23:40:49 +08:00
hiyouga
ce9bdb3509 add mixtral 8x22B models
Former-commit-id: eccbeecff0909e1fa124b5439ffbbfbc5607e1d6
2024-04-17 23:35:59 +08:00
hiyouga
0c8d6369ac add CodeQwen models
Former-commit-id: 9f6094241391f8f717818c8ba94e11d1791b4a5c
2024-04-17 23:27:22 +08:00
hiyouga
bee796f6b5 fix #3316
Former-commit-id: 7395e9e90a209228ff563ab54319955608850fc3
2024-04-17 22:54:34 +08:00
hiyouga
9f6349a333 fix #3317
Former-commit-id: 7dce1763be4374cf616d96db95ae964ff510a9d6
2024-04-17 22:17:19 +08:00
hiyouga
171a029c5e lint
Former-commit-id: 917d65ce65024d17a5030bc57083a427cfae16d7
2024-04-16 18:21:09 +08:00
hoshi-hiyouga
eaefaa0fe0 Merge pull request #3291 from codemayq/main
support for previewing custom dataset in directory format

Former-commit-id: 40d89152282101a7c08f53e72c2ad7124a0595f3
2024-04-16 18:12:09 +08:00
hiyouga
d301f0a64b Update parser.py
Former-commit-id: 92c2133896c20054db86dd53508c982e39bd5ca0
2024-04-16 18:09:31 +08:00
hiyouga
0a1578e4e3 update readme and gradio version
Former-commit-id: 4029b60ddcbd15b5354503c51178f0f5e7e9aedf
2024-04-16 18:09:16 +08:00
hiyouga
a4167fd925 support badam for all stages
Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
2024-04-16 17:44:48 +08:00
hoshi-hiyouga
42084e08ae Merge pull request #3287 from Ledzy/badam
[Feature] Add BAdam algorithm

Former-commit-id: 10a5e1e65b34b03e5ca2a41bf6ded09a3fb25f0c
2024-04-16 17:32:16 +08:00
hoshi-hiyouga
9d23f5dc89 Update utils.py
Former-commit-id: 01147536b2bb507e87e033fa696e9eb39fe96bbe
2024-04-16 17:30:12 +08:00
hoshi-hiyouga
5978427ae0 Update trainer.py
Former-commit-id: c6163be1444c00dd000f288e2f834968bd932981
2024-04-16 17:29:52 +08:00
hoshi-hiyouga
c7c216069c Update utils.py
Former-commit-id: 7edf4dbed88b8034282f14fd6e0cb6f7f9e5f805
2024-04-16 17:29:30 +08:00
hoshi-hiyouga
cde9d1b917 Update patcher.py
Former-commit-id: 494e6a1e05b38f5ff61d83327303614f53c92e64
2024-04-16 17:29:19 +08:00
hoshi-hiyouga
96213f04b0 Update adapter.py
Former-commit-id: 8f7b75b26f020d8ae85baab7b082475c3bfeb512
2024-04-16 17:28:12 +08:00
hoshi-hiyouga
7ecea08b9b Update parser.py
Former-commit-id: 898239883afc79f03abd0dc276eef901662a9591
2024-04-16 17:27:25 +08:00
hoshi-hiyouga
191971865d Update parser.py
Former-commit-id: 2f3da8169d18b026760cc0ac7dd6141bdd08c932
2024-04-16 17:27:02 +08:00
hoshi-hiyouga
ff4f587dd9 Update finetuning_args.py
Former-commit-id: 3a23d900aea74078f0bc8cf73fac860a4ce3df67
2024-04-16 17:26:30 +08:00
hoshi-hiyouga
de728d0371 Update sft.sh
Former-commit-id: 2b4b1562e91bbb02e345e71b7721da9333c0791b
2024-04-16 17:25:40 +08:00
hoshi-hiyouga
d08e09642d Update requirements.txt
Former-commit-id: 1e45537ca0bb4d49b4147df01122e365b3d617e4
2024-04-16 17:10:17 +08:00
hoshi-hiyouga
351493b183 Update setup.py
Former-commit-id: 5df30ea166aff29d48ff83a22ac6ef1611ce3e35
2024-04-16 17:10:02 +08:00
Jonery
86ab47e121 remove badam from core requirements
Former-commit-id: fa5898944a3867ac5108dd0d579ca0677c87d3d6
2024-04-16 12:25:50 +08:00
Jonery
6dd6b3e396 resolve gradient checkpointing issue.
Former-commit-id: 6df9135d063bb6102f0cbcdf0d702076f5febbae
2024-04-16 12:05:27 +08:00
codingma
5f1418a68b add check
Former-commit-id: 008f6498977c243c80e87242f05c9cf9573541ac
2024-04-16 10:56:39 +08:00
codingma
7b97a79efc support for previewing custom dataset in directory format
Former-commit-id: 501cff38c819f06f15194907ce7e052d5f28025a
2024-04-16 10:43:14 +08:00
hiyouga
ce4f653121 add empty template
Former-commit-id: a325ffa8a668bec354d2636683806acef105e196
2024-04-16 03:10:02 +08:00
hiyouga
b053c6454e update readme
Former-commit-id: 8f233745c3aa7a6ef57f275bec80ee731ff76de3
2024-04-16 02:36:54 +08:00
hiyouga
ebf0f4a77c update readme
Former-commit-id: f9a246572c1ec0e4b36bff237c6523ce629b7000
2024-04-16 02:35:36 +08:00
hiyouga
efa808069a support unsloth 2024.4
Former-commit-id: 14a83f8bc4fe44783252378fce59198194a96bb8
2024-04-16 00:25:03 +08:00
hiyouga
b5c5283dd6 add codegemma
Former-commit-id: 9324176525c2eda22962b0ca1895009b6237e6e3
2024-04-16 00:11:15 +08:00
hiyouga
b638c65519 support cohere commandR #3184
Former-commit-id: e077c36872740f6b2ac255aee9da6c4c70f28977
2024-04-15 23:26:42 +08:00
Jonery
d4d471450f Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
2024-04-15 23:15:27 +08:00
hoshi-hiyouga
3144bdec2c Merge pull request #3254 from marko1616/feature/Add-support-for-CohereForAI/c4ai-command-r-plus
Add template&support for c4ai-command-r/plus (tested)

Former-commit-id: 41d39ec4889abad050820bf153133ac3a11228a3
2024-04-15 22:59:35 +08:00
hoshi-hiyouga
c6d6c4c209 Update template.py
Former-commit-id: 00b8be7dafa65e13b344724a8d3855919ee4f631
2024-04-15 22:58:01 +08:00
hoshi-hiyouga
f5f1589662 Update constants.py
Former-commit-id: 39199f712aa7b7a1c66080d9c84651fd2eb0b425
2024-04-15 22:56:55 +08:00
hiyouga
276f2cb24e update examples
Former-commit-id: 369294b31c8a03a1cafcee83eb31a817007d3c49
2024-04-15 22:14:34 +08:00
marko1616
952b785bb3 change default_system accroding to official template
Former-commit-id: 7ad9029c5e77a87a7c324b8f90b4f80a31a5c78b
2024-04-15 20:45:46 +08:00
marko1616
72dd676208 Revert "Add support for function call(Not strictly following origin)"
This reverts commit dfaa31e991 [formerly 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6].


Former-commit-id: fac9cc6e01dd8f3bc449b656804476e1871326f0
2024-04-15 20:27:09 +08:00
marko1616
dfaa31e991 Add support for function call(Not strictly following origin)
Former-commit-id: 44f3ada4e394c06b0d972329ed2a62d2be2ea0c6
2024-04-15 20:16:52 +08:00
hoshi-hiyouga
86556b1c74 Merge pull request #3261 from khazic/main
Added specimens for single-card full parameter prediction

Former-commit-id: 60df2a9519fbd8215c3afacc831b0cc89006457a
2024-04-15 16:30:57 +08:00
hoshi-hiyouga
0c80751e87 Merge pull request #3276 from liu-zichen/fix_mixtral
fix: turn on output_router_logits of mixtral
Former-commit-id: 07bbaf5c67d00a152e5304e81b15fd9189e7bb99
2024-04-15 15:38:16 +08:00
hiyouga
9338f878a3 fix #3273
Former-commit-id: 3b20c89b342a068356ffc29c3724b645775c65db
2024-04-15 15:32:58 +08:00
liuzc
fde3d91242 fix: mixtral output_router_logits
Former-commit-id: ab3171ea97ec968b972287287ef9ee2502c6d37c
2024-04-15 12:11:49 +08:00
khazic
19adfb88a9 Upgrade README.md
Former-commit-id: 697f768d7185789ee054c94f4f161a65b8a505bc
2024-04-13 20:50:49 +08:00
khazic
daaafa900a Added specimens for single-card full parameter prediction
Former-commit-id: d8d4fb9fa4b0e1950a453682e5e186f34f085dee
2024-04-13 20:45:19 +08:00
marko1616
0dcc9e0bca Typo fix
Former-commit-id: 607625497738b2c8be736be7b0bd5c6f4cbaad5e
2024-04-13 17:30:21 +08:00
marko1616
aeec78b35c Typo fix
Former-commit-id: 51b1e49e288e66c1b0c24ac070201c988fb2a389
2024-04-13 07:52:11 +08:00
marko1616
c991654cb4 Add c4ai-command-r-plus link
Former-commit-id: acaf953ca46eca8fb378067f4ada133654e4f088
2024-04-13 07:32:40 +08:00
marko1616
f328413646 Add template&support(Not tested)
Former-commit-id: 60bb60c4dc30a9641ddb57a44ef126f0768566c4
2024-04-13 04:31:33 +08:00
hiyouga
106a0104da fix #3247
Former-commit-id: bb67c66f80627805b585d157ba807c0ce378d3f2
2024-04-12 17:41:33 +08:00
hiyouga
5486ea09e3 fix model card
Former-commit-id: 920e7149bf2b559c9829aa4b11cfb6d00bbb2f9e
2024-04-12 17:11:59 +08:00
hiyouga
31bbbb6d13 fix #3238
Former-commit-id: 4d7e81ab4722d13bec6ca1af141f94bdc74d0883
2024-04-12 14:28:11 +08:00
hiyouga
1a77de82fa set dev version
Former-commit-id: f6cc76571d2c789675883a18e0db3d0c61f33808
2024-04-11 20:27:34 +08:00
98 changed files with 2540 additions and 962 deletions

130
README.md
View File

@@ -5,7 +5,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-28-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -43,10 +43,10 @@ Choose your path:
## Features
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -68,14 +68,24 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See `examples/lora_single_gpu/sft_mllm.sh` for usage.
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
<details><summary>Full Changelog</summary>
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.
<details><summary>Full Changelog</summary>
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
@@ -102,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
@@ -126,32 +136,38 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
> [!NOTE]
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
> Remember to use the **SAME** template in training and inference.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
@@ -222,6 +238,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -241,6 +258,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details>
@@ -275,16 +293,15 @@ huggingface-cli login
\* *estimated*
| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
## Getting Started
@@ -305,7 +322,7 @@ cd LLaMA-Factory
pip install -e .[metrics]
```
Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
Extra dependencies available: deepspeed, metrics, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
<details><summary>For Windows users</summary>
@@ -319,7 +336,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
</details>
### LLaMA Board GUI
### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
> [!IMPORTANT]
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
@@ -328,9 +345,20 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
```bash
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
python src/train_web.py # or python -m llmtuner.webui.interface
```
<details><summary>For Alibaba Cloud users</summary>
If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
```bash
export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
</details>
#### Use Docker
```bash
@@ -360,7 +388,7 @@ docker compose -f ./docker-compose.yml up -d
</details>
### Command Line Interface
### Train with Command Line Interface
See [examples/README.md](examples/README.md) for usage.
@@ -370,13 +398,13 @@ Use `python src/train_bash.py -h` to display arguments description.
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--template mistral \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \
--infer_backend vllm \
--vllm_enforce_eager
```
### Use ModelScope Hub
### Download from ModelScope Hub
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
@@ -384,7 +412,7 @@ If you have trouble with downloading models and datasets from Hugging Face, you
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `modelscope/Llama-2-7b-ms`.
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
## Projects using LLaMA Factory
@@ -413,8 +441,14 @@ If you have a project that should be incorporated, please contact via email or c
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -427,7 +461,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@@ -5,13 +5,13 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-28-green)](#使用了-llama-factory-的项目)
[![Citation](https://img.shields.io/badge/citation-34-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
👋 加入我们的[微信群](assets/wechat.jpg)。
@@ -23,7 +23,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
选择你的打开方式:
- **Colab**https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **本地机器**:请见[如何使用](#如何使用)
## 目录
@@ -43,10 +43,10 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色
- **多种模型**LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **先进算法**GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -68,14 +68,24 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 `examples/lora_single_gpu/sft_mllm.sh`
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练24GB 可训练 Llama-2-7B-56k。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
<details><summary>展开日志</summary>
[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/extras/fsdp_qlora`
<details><summary>展开日志</summary>
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`
@@ -102,7 +112,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn` 参数以启用 FlashAttention-2。
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU请使用 `--flash_attn fa2` 参数以启用 FlashAttention-2。
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
@@ -126,32 +136,38 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
> [!NOTE]
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以得到更好的效果
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat模型请务必使用**对应的模板**。
>
> 请务必在训练和推理时使用**完全一致**的模板。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
@@ -222,6 +238,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -241,6 +258,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details>
@@ -275,16 +293,15 @@ huggingface-cli login
\* *估算值*
| 训练方法 | 精度 | 7B | 13B | 30B | 70B | 8x7B |
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
| 全参数 | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
| 全参数 | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| 部分参数 | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
## 如何使用
@@ -305,7 +322,7 @@ cd LLaMA-Factory
pip install -e .[metrics]
```
可选的额外依赖项deepspeed、metrics、unsloth、galore、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
可选的额外依赖项deepspeed、metrics、galore、badam、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
<details><summary>Windows 用户指南</summary>
@@ -319,18 +336,29 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
</details>
### LLaMA Board 可视化界面
### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
> [!IMPORTANT]
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行分布式训练。
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。
#### 使用本地环境
```bash
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
python src/train_web.py # 或 python -m llmtuner.webui.interface
```
<details><summary>阿里云用户指南</summary>
如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
```bash
export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
</details>
#### 使用 Docker
```bash
@@ -360,23 +388,23 @@ docker compose -f ./docker-compose.yml up -d
</details>
### 命令行接口
### 利用命令行接口训练
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
使用 `python src/train_bash.py -h` 查看参数文档。
您可以执行 `python src/train_bash.py -h` 查看参数文档。
### 使OpenAI 风格 API 和 vLLM 部署
### 用 vLLM 部署 OpenAI API
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--template mistral \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \
--infer_backend vllm \
--vllm_enforce_eager
```
### 使用魔搭社区
### 魔搭社区下载
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
@@ -384,11 +412,11 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
```
`--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `modelscope/Llama-2-7b-ms`
`--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`
## 使用了 LLaMA Factory 的项目
如果您有项目希望添加至述列表,请通过邮件联系或者创建一个 PR。
如果您有项目希望添加至述列表,请通过邮件联系或者创建一个 PR。
<details><summary>点击显示</summary>
@@ -413,8 +441,14 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
@@ -427,7 +461,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2/LLaVA-1.5](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用

View File

@@ -18,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
"history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)"
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",

View File

@@ -18,7 +18,8 @@
"history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None"
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",

View File

@@ -1 +1 @@
34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
a97cf9475291591843976554878568e046d8a46d

View File

@@ -1,5 +1,6 @@
import os
import json
import os
import datasets
@@ -22,31 +23,19 @@ _URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0
class BelleMultiturn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self):
features = datasets.Features({
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
features = datasets.Features(
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": file_path
}
)
]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str):
with open(filepath, "r", encoding="utf-8") as f:
@@ -58,7 +47,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
assist_idx = prompt.rfind("Assistant:")
human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip()
query = prompt[human_idx + 6 : assist_idx].strip()
prompt = prompt[:human_idx].strip()
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
@@ -67,8 +56,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
assist_idx = prompt.rfind("Assistant:")
human_idx = prompt.rfind("Human:")
if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip()
old_query = prompt[human_idx + 6 : assist_idx].strip()
old_resp = prompt[assist_idx + 10 :].strip()
conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else:

View File

@@ -1,7 +1,8 @@
import json
import datasets
from typing import Any, Dict, Generator, List, Tuple
import datasets
_DESCRIPTION = "An example of dataset."
_CITATION = ""
@@ -11,34 +12,24 @@ _URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": file_path
}
)
]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))

View File

@@ -1,8 +1,10 @@
import os
import json
import datasets
import os
from typing import List
import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
@@ -14,50 +16,37 @@ _URLS = {
_URL + "harmless-base/train.jsonl.gz",
_URL + "helpful-base/train.jsonl.gz",
_URL + "helpful-online/train.jsonl.gz",
_URL + "helpful-rejection-sampled/train.jsonl.gz"
_URL + "helpful-rejection-sampled/train.jsonl.gz",
],
"test": [
_URL + "harmless-base/test.jsonl.gz",
_URL + "helpful-base/test.jsonl.gz",
_URL + "helpful-online/test.jsonl.gz",
_URL + "helpful-rejection-sampled/test.jsonl.gz"
]
_URL + "helpful-rejection-sampled/test.jsonl.gz",
],
}
class HhRlhfEn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Sequence(datasets.Value("string")),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"output": datasets.Sequence(datasets.Value("string")),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download_and_extract(_URLS)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_path["train"]
}
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepaths": file_path["test"]
}
)
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
]
def _generate_examples(self, filepaths: List[str]):
@@ -70,12 +59,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
rejected = data["rejected"]
assist_idx = rejected.rfind("\n\nAssistant: ")
r_reject = rejected[assist_idx+13:].strip()
r_reject = rejected[assist_idx + 13 :].strip()
assist_idx = chosen.rfind("\n\nAssistant: ")
r_accept = chosen[assist_idx+13:].strip()
r_accept = chosen[assist_idx + 13 :].strip()
human_idx = chosen.rfind("\n\nHuman: ")
query = chosen[human_idx+9:assist_idx].strip()
query = chosen[human_idx + 9 : assist_idx].strip()
prompt = chosen[:human_idx]
history = []
@@ -83,16 +72,12 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
assist_idx = prompt.rfind("\n\nAssistant: ")
human_idx = prompt.rfind("\n\nHuman: ")
if human_idx != -1:
old_query = prompt[human_idx+9:assist_idx].strip()
old_resp = prompt[assist_idx+13:].strip()
old_query = prompt[human_idx + 9 : assist_idx].strip()
old_resp = prompt[assist_idx + 13 :].strip()
history.insert(0, (old_query, old_resp))
else:
break
prompt = prompt[:human_idx]
yield key, {
"instruction": query,
"output": [r_accept, r_reject],
"history": history
}
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
key += 1

View File

@@ -1,8 +1,10 @@
import os
import json
import datasets
import os
from typing import List
import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
@@ -24,31 +26,19 @@ _BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jso
class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self):
features = datasets.Features({
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
features = datasets.Features(
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_paths
}
)
]
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths:
@@ -56,7 +46,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
for row in f:
try:
data = json.loads(row)
except:
except Exception:
continue
key: int = data["id"]
content: List[str] = data["data"]
@@ -64,8 +54,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
content.pop(-1)
if len(content) < 2:
continue
conversations = [{
"from": "human" if i % 2 == 0 else "gpt",
"value": content[i]
} for i in range(len(content))]
conversations = [
{"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
]
yield key, {"conversations": conversations}

View File

@@ -3,41 +3,48 @@ We provide diverse examples about fine-tuning LLMs.
```
examples/
├── lora_single_gpu/
│ ├── pretrain.sh: Do pre-training
│ ├── sft.sh: Do supervised fine-tuning
│ ├── reward.sh: Do reward modeling
│ ├── ppo.sh: Do PPO training
│ ├── dpo.sh: Do DPO training
│ ├── orpo.sh: Do ORPO training
│ ├── pretrain.sh: Do continuous pre-training using LoRA
│ ├── sft.sh: Do supervised fine-tuning using LoRA
│ ├── reward.sh: Do reward modeling using LoRA
│ ├── ppo.sh: Do PPO training using LoRA
│ ├── dpo.sh: Do DPO training using LoRA
│ ├── orpo.sh: Do ORPO training using LoRA
│ ├── sft_mllm.sh: Do supervised fine-tuning on multimodal data using LoRA
│ ├── prepare.sh: Save tokenized dataset
│ └── predict.sh: Do batch predict
│ └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
├── qlora_single_gpu/
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models
│ ├── awq.sh: Fine-tune 4-bit AWQ models
│ └── aqlm.sh: Fine-tune 2-bit AQLM models
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
│ ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
│ └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
├── lora_multi_gpu/
│ ├── single_node.sh: Fine-tune model with Accelerate on single node
── multi_node.sh: Fine-tune model with Accelerate on multiple nodes
│ ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
│ └── ds_zero3.sh: Fine-tune model with DeepSpeed ZeRO-3 using LoRA (weight sharding)
├── full_multi_gpu/
│ ├── single_node.sh: Fine-tune model with DeepSpeed on single node
── multi_node.sh: Fine-tune model with DeepSpeed on multiple nodes
│ ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
│ └── predict.sh: Do parallel batch predict and compute BLEU and ROUGE scores after full tuning
├── merge_lora/
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
│ └── quantize.sh: Quantize fine-tuned model with AutoGPTQ
│ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
├── inference/
│ ├── cli_demo.sh: Launch a command line interface
│ ├── api_demo.sh: Launch an OpenAI-style API
│ ├── web_demo.sh: Launch a web interface
│ └── evaluate.sh: Evaluate model on the MMLU benchmark
│ ├── cli_demo.sh: Chat with fine-tuned model in the CLI with LoRA adapters
│ ├── api_demo.sh: Chat with fine-tuned model in an OpenAI-style API with LoRA adapters
│ ├── web_demo.sh: Chat with fine-tuned model in the Web browser with LoRA adapters
│ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
└── extras/
├── galore/
│ └── sft.sh: Fine-tune model with GaLore
├── badam/
│ └── sft.sh: Fine-tune model with BAdam
├── loraplus/
│ └── sft.sh: Fine-tune model with LoRA+
│ └── sft.sh: Fine-tune model using LoRA+
├── mod/
│ └── sft.sh: Fine-tune model using Mixture-of-Depths
├── llama_pro/
│ ├── expand.sh: Expand layers in the model
│ └── sft.sh: Fine-tune expanded model
│ └── sft.sh: Fine-tune the expanded model
└── fsdp_qlora/
└── sft.sh: Fine-tune quantized model with FSDP
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
```

View File

@@ -1,43 +1,50 @@
我们提供了多样化的示例脚本。
我们提供了多样化的大模型微调示例脚本。
```
examples/
├── lora_single_gpu/
│ ├── pretrain.sh: 进行预训练
│ ├── sft.sh: 进行指令监督微调
│ ├── reward.sh: 进行奖励模型训练
│ ├── ppo.sh: 进行 PPO 训练
│ ├── dpo.sh: 进行 DPO 训练
│ ├── orpo.sh: 进行 ORPO 训练
│ ├── pretrain.sh: 基于 LoRA 进行增量预训练
│ ├── sft.sh: 基于 LoRA 进行指令监督微调
│ ├── reward.sh: 基于 LoRA 进行奖励模型训练
│ ├── ppo.sh: 基于 LoRA 进行 PPO 训练
│ ├── dpo.sh: 基于 LoRA 进行 DPO 训练
│ ├── orpo.sh: 基于 LoRA 进行 ORPO 训练
│ ├── sft_mllm.sh: 基于 LoRA 进行多模态指令监督微调
│ ├── prepare.sh: 保存预处理后的数据集
│ └── predict.sh: 进行批量预测
│ └── predict.sh: 基于 LoRA 进行批量预测并计算 BLEU 和 ROUGE 分数
├── qlora_single_gpu/
│ ├── bitsandbytes.sh: 微调 4/8 比特 BNB 模型
│ ├── gptq.sh: 微调 4/8 比特 GPTQ 模型
│ ├── awq.sh: 微调 4 比特 AWQ 模型
│ └── aqlm.sh: 微调 2 比特 AQLM 模型
│ ├── bitsandbytes.sh: 基于 QLoRA 微调 4/8 比特 BNB 模型
│ ├── gptq.sh: 基于 QLoRA 微调 4/8 比特 GPTQ 模型
│ ├── awq.sh: 基于 QLoRA 微调 4 比特 AWQ 模型
│ └── aqlm.sh: 基于 QLoRA 微调 2 比特 AQLM 模型
├── lora_multi_gpu/
│ ├── single_node.sh: 使用 Accelerate 进行单节点训练
── multi_node.sh: 使用 Accelerate 进行多节点训练
│ ├── single_node.sh: 使用 Accelerate 进行单节点 LoRA 训练
── multi_node.sh: 使用 Accelerate 进行多节点 LoRA 训练
│ └── ds_zero3.sh: 使用 DeepSpeed ZeRO-3 进行 LoRA 训练(拆分权重)
├── full_multi_gpu/
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点训练
── multi_node.sh: 使用 DeepSpeed 进行多节点训练
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点全量训练
── multi_node.sh: 使用 DeepSpeed 进行多节点全量训练
│ └── predict.sh: 基于全量训练进行多卡批量预测并计算 BLEU 和 ROUGE 分数
├── merge_lora/
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
│ └── quantize.sh: 使用 AutoGPTQ 量化模型
│ └── quantize.sh: 使用 AutoGPTQ 量化微调后的模型
├── inference/
│ ├── cli_demo.sh: 启动命令行推理接口
│ ├── api_demo.sh: 启动 OpenAI 风格 API
│ ├── web_demo.sh: 启动浏览器推理接口
│ └── evaluate.sh: 在 MMLU 数据集上评测模型
│ ├── cli_demo.sh: 启动 LoRA 模型的命令行推理接口
│ ├── api_demo.sh: 启动 LoRA 模型的 OpenAI 风格 API
│ ├── web_demo.sh: 启动 LoRA 模型的浏览器推理接口
│ └── evaluate.sh: 在 MMLU/CMMLU/C-Eval 数据集上评测 LoRA 模型
└── extras/
├── galore/
│ └── sft.sh: 使用 GaLore 训练模型
├── badam/
│ └── sft.sh: 使用 BAdam 训练模型
├── loraplus/
│ └── sft.sh: 使用 LoRA+ 训练模型
├── mod/
│ └── sft.sh: 使用深度混合训练模型
├── llama_pro/
│ ├── expand.sh: 扩展模型中的层
│ └── sft.sh: 训练扩展后的模型
└── fsdp_qlora/
└── sft.sh: 使用 FSDP 微调量化模型
└── sft.sh: 使用 FSDP+QLoRA 微调量化模型
```

View File

@@ -9,7 +9,7 @@ main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 16 # the number of GPUs in all nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

View File

@@ -9,7 +9,7 @@ main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 16 # the number of GPUs in all nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []

View File

@@ -0,0 +1,35 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--use_badam \
--badam_switch_mode descending \
--badam_switch_block_every 50 \
--badam_verbose 2 \
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA
pip install "transformers>=4.39.1"
pip install "accelerate>=0.28.0"

View File

@@ -12,6 +12,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--galore_layerwise \
--galore_target mlp,self_attn \
--galore_rank 128 \
--galore_scale 2.0 \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \

View File

@@ -9,6 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--loraplus_lr_ratio 16.0 \
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
--overwrite_cache \
--overwrite_output_dir \
@@ -29,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16 \
--loraplus_lr_ratio 16.0
--fp16

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--mixture_of_depths convert \
--output_dir ../../../saves/LLaMA2-7B/mod/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--optim paged_adamw_8bit \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -0,0 +1,20 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \
--stage sft \
--do_predict \
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type full \
--output_dir ../../saves/LLaMA2-7B/full/predict \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_eval_batch_size 1 \
--max_samples 20 \
--predict_with_generate

View File

@@ -3,7 +3,7 @@
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template vanilla \
--template fewshot \
--finetuning_type lora \
--task mmlu \
--split test \

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# add `--visual_inputs True` to load MLLM
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@@ -0,0 +1,33 @@
#!/bin/bash
deepspeed --num_gpus 4 ../../src/train_bash.py \
--deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--ddp_timeout 180000000 \
--plot_loss \
--fp16

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# also launch it on slave machine using slave_config.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/master_config.yaml \

View File

@@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \
--stage sft \

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--visual_inputs \
--dataset mllm_demo \
--dataset_dir ../../data \
--template vicuna \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/sft_mllm \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 100.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -1,11 +1,12 @@
#!/bin/bash
# DO NOT use quantized model or quantization_bit when merging lora weights
CUDA_VISIBLE_DEVICES= python ../../src/export_model.py \
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \
--finetuning_type lora \
--export_dir ../../models/llama2-7b-sft \
--export_size 2 \
--export_device cpu \
--export_legacy_format False

View File

@@ -1,4 +1,5 @@
#!/bin/bash
# NEED TO run `merge.sh` before using this script
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path ../../models/llama2-7b-sft \

View File

@@ -4,7 +4,7 @@ datasets>=2.14.3
accelerate>=0.27.2
peft>=0.10.0
trl>=0.8.1
gradio>=4.0.0,<=4.21.0
gradio>=4.0.0
scipy
einops
sentencepiece
@@ -15,3 +15,4 @@ fastapi
sse-starlette
matplotlib
fire
packaging

View File

@@ -44,8 +44,9 @@ def calculate_lr(
overwrite_cache=True,
)
)
tokenizer = load_tokenizer(model_args)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":

View File

@@ -32,8 +32,8 @@ def length_cdf(
overwrite_cache=True,
)
)
tokenizer = load_tokenizer(model_args)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
tokenizer_module = load_tokenizer(model_args)
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):

View File

@@ -22,9 +22,9 @@ def get_requires():
extra_require = {
"deepspeed": ["deepspeed>=0.10.0"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
"galore": ["galore-torch"],
"vllm": ["vllm>=0.3.3"],
"badam": ["badam"],
"vllm": ["vllm>=0.4.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],

View File

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.6.2"
__version__ = "0.7.0"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass
class Response:
@@ -49,6 +47,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]: ...
@@ -58,6 +57,7 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...

View File

@@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from numpy.typing import NDArray
from .base_engine import BaseEngine, Response
@@ -36,9 +38,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
return task.result()
async def achat(
@@ -46,18 +49,20 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, **input_kwargs)
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, **input_kwargs)
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -70,9 +75,10 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
yield new_token
def get_scores(

View File

@@ -14,7 +14,9 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper
from ..data import Template
@@ -30,7 +32,9 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.tokenizer = load_tokenizer(model_args)
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model(
@@ -42,13 +46,18 @@ class HuggingfaceEngine(BaseEngine):
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
@@ -95,6 +104,11 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
if processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
return gen_kwargs, prompt_length
@staticmethod
@@ -102,15 +116,17 @@ class HuggingfaceEngine(BaseEngine):
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -135,15 +151,17 @@ class HuggingfaceEngine(BaseEngine):
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -199,6 +217,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -208,11 +227,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:
@@ -224,6 +245,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -233,11 +255,13 @@ class HuggingfaceEngine(BaseEngine):
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
input_kwargs,
)
async with self._semaphore:

View File

@@ -1,19 +1,24 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
from ..model import load_tokenizer
from ..model import load_config, load_tokenizer
from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import MultiModalData
if TYPE_CHECKING:
import torch
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -25,32 +30,59 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
config = load_config(model_args) # may download model from ms hub
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
infer_dtype = str(infer_dtype).split(".")[-1]
self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path,
trust_remote_code=True,
max_model_len=model_args.vllm_maxlen,
tensor_parallel_size=get_device_count() or 1,
gpu_memory_utilization=model_args.vllm_gpu_util,
disable_log_stats=True,
disable_log_requests=True,
enforce_eager=model_args.vllm_enforce_eager,
)
self.model = AsyncLLMEngine.from_engine_args(engine_args)
self.tokenizer = load_tokenizer(model_args)
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
"dtype": infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if model_args.visual_inputs:
# TODO: auto derive from config
# https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
self.image_feature_size = 576
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
engine_args["image_input_shape"] = "1,3,336,336"
engine_args["image_feature_size"] = self.image_feature_size
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
self.lora_request = None
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
@@ -94,8 +126,21 @@ class VllmEngine(BaseEngine):
max_tokens=generating_args["max_new_tokens"],
skip_special_tokens=True,
)
if self.processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
result_generator = self.model.generate(
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
prompt=None,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=prompt_ids,
lora_request=self.lora_request,
multi_modal_data=multi_modal_data,
)
return result_generator
@@ -107,10 +152,11 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, **input_kwargs)
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for request_output in generator:
final_output = request_output
@@ -132,10 +178,11 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, **input_kwargs)
generator = await self._generate(messages, system, tools, image, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text

View File

@@ -1,3 +1,4 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -13,8 +14,23 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
outputs.append(os.path.join(data_args.dataset_dir, image))
else:
outputs.append(image)
return outputs
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@@ -44,12 +60,16 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -84,6 +104,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
@@ -96,12 +117,13 @@ def align_dataset(
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
tools: "...",
images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
@@ -114,6 +136,7 @@ def align_dataset(
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}

View File

@@ -1,6 +1,6 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
from datasets import load_dataset, load_from_disk
@@ -16,7 +16,7 @@ from .utils import checksum, merge_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
@@ -115,11 +115,12 @@ def load_single_dataset(
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
@@ -149,7 +150,7 @@ def get_dataset(
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}

View File

@@ -28,6 +28,7 @@ class DatasetAttr:
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
system: Optional[str] = None
images: Optional[str] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
@@ -105,7 +106,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system"]
column_names = ["system", "images"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:

View File

@@ -1,14 +1,22 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from ..extras.packages import is_pillow_available
from .utils import Role
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
@@ -18,6 +26,13 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0]
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
@@ -48,18 +63,25 @@ def preprocess_pretrain_dataset(
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
@@ -89,14 +111,16 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
@@ -141,17 +165,24 @@ def preprocess_packed_supervised_dataset(
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
@@ -172,22 +203,32 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
@@ -214,6 +255,8 @@ def preprocess_pairwise_dataset(
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
@@ -244,34 +287,54 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

View File

@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
template = templates["empty"] # placeholder
else:
template = templates.get(name, None)
if template is None:
@@ -385,7 +385,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
)
@@ -502,6 +503,7 @@ _register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
@@ -512,6 +514,7 @@ _register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
@@ -526,6 +529,21 @@ _register_template(
)
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -534,6 +552,32 @@ _register_template(
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)\nThis is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -566,6 +610,13 @@ _register_template(
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -574,10 +625,20 @@ _register_template(
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
efficient_eos=True,
force_system=True,
@@ -635,9 +696,36 @@ _register_template(
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
default_system="You are a helpful assistant.",
stop_words=["<|eot_id|>"],
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
@@ -669,10 +757,23 @@ _register_template(
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -699,13 +800,6 @@ _register_template(
)
_register_template(
name="vanilla",
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -776,7 +870,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate",
default_system="You are Zephyr, a helpful assistant.",
)

View File

@@ -78,9 +78,9 @@ def split_dataset(
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size

View File

@@ -21,7 +21,7 @@ from .template import get_eval_template
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)

View File

@@ -28,6 +28,10 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
MLLM_LIST = ["LLaVA1.5"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
@@ -45,6 +49,8 @@ TRAINING_STAGES = {
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -242,6 +248,44 @@ register_model_group(
)
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group(
models={
"DBRX-132B-Base": {
DownloadSource.DEFAULT: "databricks/dbrx-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
},
"DBRX-132B-Chat": {
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
},
},
module="Wqkv",
template="dbrx",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
@@ -262,9 +306,11 @@ register_model_group(
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
@@ -363,6 +409,23 @@ register_model_group(
)
register_model_group(
models={
"CodeGemma-2B": {
DownloadSource.DEFAULT: "google/codegemma-2b",
},
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"InternLM-7B": {
@@ -410,6 +473,16 @@ register_model_group(
)
register_model_group(
models={
"Jambda-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
},
)
register_model_group(
models={
"LingoWhale-8B": {
@@ -474,6 +547,42 @@ register_model_group(
)
register_model_group(
models={
"LLaMA3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"LLaMA3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"LLaMA3-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"LLaMA3-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
},
template="llama3",
)
register_model_group(
models={
"LLaVA1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
},
"LLaVA1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
@@ -499,14 +608,21 @@ register_model_group(
register_model_group(
models={
"Mixtral-8x7B": {
"Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-Chat": {
"Mixtral-8x7B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
},
},
template="mistral",
)
@@ -515,18 +631,15 @@ register_model_group(
register_model_group(
models={
"OLMo-1B": {
DownloadSource.DEFAULT: "allenai/OLMo-1B",
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
},
"OLMo-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-7B",
DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
},
"OLMo-7B-Chat": {
DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
"OLMo-1.7-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
},
},
module="att_proj",
template="olmo",
)
@@ -534,7 +647,7 @@ register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
}
},
template="openchat",
@@ -582,6 +695,22 @@ register_model_group(
)
register_model_group(
models={
"Phi3-3.8B-4k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi3-3.8B-128k-Chat": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-128k-instruct",
},
},
module="qkv_proj",
template="phi",
)
register_model_group(
models={
"Qwen-1.8B": {
@@ -684,10 +813,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-Code-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -716,10 +853,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-Code-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
@@ -772,10 +917,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"Qwen1.5-Code-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
@@ -809,12 +962,15 @@ register_model_group(
models={
"StarCoder2-3B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
},
"StarCoder2-7B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
},
"StarCoder2-15B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
},
}
)
@@ -837,17 +993,53 @@ register_model_group(
register_model_group(
models={
"XuanYuan-6B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
},
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan-2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
"XuanYuan-6B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan-2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan-2-70B-int8-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan-2-70B-int4-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
},
template="xuanyuan",
@@ -884,6 +1076,30 @@ register_model_group(
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-int8-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-int4-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
@@ -976,21 +1192,9 @@ register_model_group(
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
"Zephyr-141B-ORPO-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
},
},
template="zephyr",
)
register_model_group(
models={
"Atom-7B": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
},
"Atom-7B-Chat": {
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
},
},
template="atom",
)

View File

@@ -66,7 +66,6 @@ def check_dependencies() -> None:
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
@@ -84,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else:
num_bytes = 1

View File

@@ -1,16 +1,23 @@
import importlib.metadata
import importlib.util
from typing import TYPE_CHECKING
from packaging import version
if TYPE_CHECKING:
from packaging.version import Version
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def _get_package_version(name: str) -> str:
def _get_package_version(name: str) -> "Version":
try:
return importlib.metadata.version(name)
return version.parse(importlib.metadata.version(name))
except Exception:
return "0.0.0"
return version.parse("0.0.0")
def is_fastapi_availble():
@@ -18,13 +25,17 @@ def is_fastapi_availble():
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available():
return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available():
return _is_package_available("jieba")
@@ -37,6 +48,10 @@ def is_nltk_available():
return _is_package_available("nltk")
def is_pillow_available():
return _is_package_available("PIL")
def is_requests_available():
return _is_package_available("requests")
@@ -45,14 +60,14 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available():
return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available():
return _is_package_available("uvicorn")

View File

@@ -26,11 +26,11 @@ class DataArguments:
)
cutoff_len: int = field(
default=1024,
metadata={"help": "The cutoff length of the model inputs after tokenization."},
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
reserved_label_len: int = field(
default=1,
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
)
train_on_prompt: bool = field(
default=False,

View File

@@ -172,7 +172,7 @@ class GaloreArguments:
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use gradient low-Rank projection."},
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
@@ -204,7 +204,54 @@ class GaloreArguments:
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_update_ratio: float = field(
default=0.0,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": """The mode of the mask for BAdam optimizer. \
`adjacent` means that the trainable parameters are adjacent to each other, \
`scatter` means that trainable parameters are randomly choosed from the weight."""
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": """The verbosity level of BAdam optimizer. \
0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
},
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
@@ -256,11 +303,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"

View File

@@ -31,11 +31,11 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
)
max_length: int = field(
default=512,
default=1024,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
)
max_new_tokens: int = field(
default=512,
default=1024,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
)
repetition_penalty: float = field(

View File

@@ -22,7 +22,7 @@ class ModelArguments:
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=False,
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
@@ -33,6 +33,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -55,24 +59,32 @@ class ModelArguments:
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used for loading the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: bool = field(
default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."},
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
@@ -129,6 +141,10 @@ class ModelArguments:
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: str = field(
default="cpu",
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
@@ -166,9 +182,15 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."

View File

@@ -8,10 +8,10 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies
from ..extras.packages import is_unsloth_available
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -67,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
@@ -74,6 +77,35 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam", "To fix: pip install badam")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
@@ -131,8 +163,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
@@ -151,21 +183,33 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed.")
if (
finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -235,6 +279,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
@@ -266,18 +311,25 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.adapter_name_or_path is not None:
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support quantization.")
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
_verify_model_args(model_args, finetuning_args)
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
model_args.device_map = "auto"
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None:
model_args.device_map = {"": torch.device(model_args.export_device)}
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args
@@ -294,6 +346,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"

View File

@@ -1,8 +1,10 @@
from .loader import load_model, load_tokenizer
from .utils import find_all_linear_modules, load_valuehead_params
from .loader import load_config, load_model, load_tokenizer
from .utils.misc import find_all_linear_modules
from .utils.valuehead import load_valuehead_params
__all__ = [
"load_config",
"load_model",
"load_tokenizer",
"load_valuehead_params",

View File

@@ -5,11 +5,13 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras.logging import get_logger
from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
from .utils.misc import find_all_linear_modules, find_expanded_modules
from .utils.quantization import QuantizationMethod
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
@@ -18,7 +20,11 @@ logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -32,9 +38,12 @@ def init_adapter(
logger.info("Adapter is not found at evaluation, load the base model.")
return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -66,6 +75,8 @@ def init_adapter(
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
@@ -79,7 +90,7 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -100,6 +111,10 @@ def init_adapter(
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
@@ -116,9 +131,15 @@ def init_adapter(
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(
model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
)
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -129,9 +150,23 @@ def init_adapter(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
@@ -139,24 +174,21 @@ def init_adapter(
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if not finetuning_args.pure_bf16:
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

View File

@@ -1,17 +1,20 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
from .utils.misc import register_autoclass
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .utils.unsloth import load_unsloth_pretrained_model
from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
@@ -19,7 +22,17 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
@@ -29,22 +42,56 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
}
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer. Must before load_model.
Loads pretrained tokenizer.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
return tokenizer
if model_args.visual_inputs:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
else:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
@@ -55,45 +102,42 @@ def load_model(
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model. Must after load_tokenizer.
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
model = None
if is_trainable and model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args)
unsloth_kwargs = {
"model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
}
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.adapter_name_or_path:
model_args.adapter_name_or_path = None
logger.warning("Unsloth does not support loading adapters.")
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model is None:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:

View File

@@ -1,23 +1,22 @@
import math
import os
import random
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Dict
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
from .utils import QuantizationMethod, add_z3_leaf_module
from ..extras.misc import infer_optim_dtype
from .utils.attention import configure_attn_implementation, print_attn_implementation
from .utils.checkpointing import prepare_model_for_training
from .utils.embedding import resize_embedding_layer
from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.valuehead import configure_valuehead, prepare_valuehead_model
from .utils.visual import autocast_projector_dtype
if TYPE_CHECKING:
@@ -28,254 +27,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def _configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
) -> None:
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
return
logger.info("Using FlashAttention-2 for faster training and inference.")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", "flash_attention_2")
else:
init_kwargs["attn_implementation"] = "flash_attention_2"
else:
init_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.")
return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def _prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(_fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -289,25 +40,24 @@ def patch_config(
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
add_valuehead: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
_configure_attn_implementation(config, model_args, init_kwargs)
_configure_rope(config, model_args, is_trainable)
_configure_longlora(config, model_args, is_trainable)
_configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_attn_implementation(config, model_args)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
if add_valuehead:
configure_valuehead(config)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.")
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
@@ -316,22 +66,23 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
setattr(config, "output_router_logits", True)
init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs:
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map
if init_kwargs["device_map"] == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool,
add_valuehead: bool,
) -> None:
gen_config = model.generation_config # check and fix generation config
if not gen_config.do_sample and (
@@ -344,25 +95,21 @@ def patch_model(
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if add_valuehead:
prepare_valuehead_model(model)
if model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer)
resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable:
_prepare_model_for_training(model, model_args)
prepare_model_for_training(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
add_z3_leaf_module(model, MixtralSparseMoeBlock)
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
if not model_args.use_unsloth:
print_attn_implementation(model.config)
try:
model.add_model_tags(["llama-factory"])

View File

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "off":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
logger.warning("Torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla Attention implementation.")

View File

@@ -0,0 +1,94 @@
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(_fp32_forward_post_hook)

View File

@@ -0,0 +1,58 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))

View File

@@ -1,5 +1,5 @@
import math
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
@@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
def llama_torch_attn_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -39,10 +48,11 @@ def llama_torch_attn_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -69,8 +79,9 @@ def llama_torch_attn_forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -97,8 +108,8 @@ def llama_torch_attn_forward(
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -117,7 +128,6 @@ def llama_flash_attn_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -134,9 +144,10 @@ def llama_flash_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
@@ -192,7 +203,115 @@ def llama_flash_attn_forward(
return attn_output, attn_weights, past_key_value
def apply_llama_patch() -> None:
require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3")
LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
return llama_attention_forward(
self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, :groupsz]
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _apply_llama_patch() -> None:
require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")

View File

@@ -1,49 +1,18 @@
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, List
import torch
from transformers import PreTrainedModel
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import cached_file
from transformers.utils.versions import require_version
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ...extras.logging import get_logger
from .quantization import QuantizationMethod
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ..hparams import ModelArguments
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, [module])
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r"""
Finds all available modules to apply lora or galore.
@@ -100,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
return module_names
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
from MoD import AutoMoDModelForCausalLM
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
def convert_pretrained_model_to_mod(
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
return model

View File

@@ -0,0 +1,53 @@
from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if not is_deepspeed_zero3_enabled():
return
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
set_z3_leaf_modules(model, [DbrxFFN])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@@ -0,0 +1,146 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ...hparams import ModelArguments
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

View File

@@ -0,0 +1,47 @@
import math
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"use_gradient_checkpointing": "unsloth",
}
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False
return model
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
if not is_trainable:
FastLanguageModel.for_inference(model)
return model

View File

@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_valuehead(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava":
setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
if getattr(model.config, "model_type", None) == "llava":
setattr(model, "lm_head", model.language_model.get_output_embeddings())
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Tuple
import torch
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name):
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
@@ -63,6 +64,11 @@ class CustomDPOTrainer(DPOTrainer):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)

View File

@@ -24,8 +24,9 @@ def run_dpo(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
@@ -44,6 +45,10 @@ class CustomORPOTrainer(DPOTrainer):
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:

View File

@@ -24,8 +24,9 @@ def run_orpo(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(

View File

@@ -1,6 +1,7 @@
import math
import os
import sys
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -124,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.

View File

@@ -27,8 +27,9 @@ def run_ppo(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training

View File

@@ -1,3 +1,4 @@
from types import MethodType
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
@@ -23,6 +24,10 @@ class CustomTrainer(Trainer):
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:

View File

@@ -25,8 +25,9 @@ def run_pt(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

View File

@@ -1,5 +1,6 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -28,6 +29,10 @@ class PairwiseTrainer(Trainer):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:

View File

@@ -25,8 +25,9 @@ def run_rm(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import numpy as np
from transformers.utils.versions import require_version
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
@@ -33,10 +32,6 @@ class ComputeMetrics:
r"""
Uses the model predictions to compute metrics.
"""
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
preds, labels = eval_preds
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

View File

@@ -1,5 +1,6 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -28,6 +29,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:

View File

@@ -28,8 +28,9 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if training_args.predict_with_generate:
@@ -47,6 +48,7 @@ def run_sft(
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(

View File

@@ -52,7 +52,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.")
tokenizer = load_tokenizer(model_args)
tokenizer = load_tokenizer(model_args)["tokenizer"]
get_template_and_fix_tokenizer(tokenizer, data_args.template)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
@@ -65,8 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
setattr(model.config, "torch_dtype", output_dtype)
for param in model.parameters():
param.data = param.data.to(output_dtype)
model = model.to(output_dtype)
model.save_pretrained(
save_directory=model_args.export_dir,

View File

@@ -5,7 +5,6 @@ from transformers import Trainer
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
@@ -57,9 +56,14 @@ def create_modelcard_and_push(
kwargs = {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory", finetuning_args.finetuning_type],
}
if data_args.dataset is not None:
kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]
if model_args.use_unsloth:
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
if not training_args.do_train:
pass
elif training_args.push_to_hub:
@@ -87,7 +91,7 @@ def create_ref_model(
)
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
tokenizer = load_tokenizer(ref_model_args)
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
@@ -96,7 +100,7 @@ def create_ref_model(
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
tokenizer = load_tokenizer(model_args)
tokenizer = load_tokenizer(model_args)["tokenizer"]
ref_model = load_model(
tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
@@ -143,7 +147,7 @@ def create_reward_model(
)
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
tokenizer = load_tokenizer(reward_model_args)
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
@@ -166,8 +170,6 @@ def _create_galore_optimizer(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
require_version("galore_torch", "To fix: pip install galore_torch")
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model)
else:
@@ -217,7 +219,7 @@ def _create_galore_optimizer(
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
for param in nodecay_params:
param_groups = [dict(params=[param])]
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
@@ -237,7 +239,7 @@ def _create_galore_optimizer(
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params),
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
]
@@ -252,11 +254,9 @@ def _create_loraplus_optimizer(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if finetuning_args.finetuning_type != "lora":
raise ValueError("You should use LoRA tuning to activate LoRA+.")
default_lr = training_args.learning_rate
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
decay_args = {"weight_decay": training_args.weight_decay}
embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
@@ -279,16 +279,76 @@ def _create_loraplus_optimizer(
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=param_dict["lora_a"], **decay_args),
dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
return optimizer
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
]
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer
base_optimizer = optim_class(param_groups, **optim_kwargs)
optimizer = BlockOptimizer(
base_optimizer=base_optimizer,
named_parameters_list=list(model.named_parameters()),
block_prefix_list=None,
switch_block_every=finetuning_args.badam_switch_block_every,
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
)
logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_block_every} steps, "
f"default start block is {finetuning_args.badam_start_block}"
)
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
param_groups=param_groups,
named_parameters_list=list(model.named_parameters()),
update_ratio=finetuning_args.badam_update_ratio,
mask_mode=finetuning_args.badam_mask_mode,
verbose=finetuning_args.badam_verbose,
include_embedding=False,
**optim_kwargs,
)
logger.info(
f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)
return optimizer
def create_custom_optimzer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
@@ -300,6 +360,9 @@ def create_custom_optimzer(
if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
@@ -314,13 +377,12 @@ def create_custom_scheduler(
scheduler_dict[param] = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer_dict[param],
num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
num_training_steps=num_training_steps * 2,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
def scheduler_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
scheduler_dict[param].step()
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)

View File

@@ -1,13 +1,13 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from numpy.typing import NDArray
from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir
from .locales import ALERTS
@@ -17,6 +17,10 @@ if TYPE_CHECKING:
from .manager import Manager
if is_gradio_available():
import gradio as gr
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
@@ -29,13 +33,16 @@ class WebChatModel(ChatModel):
if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model
model_name_or_path = os.environ.get("DEMO_MODEL")
template = os.environ.get("DEMO_TEMPLATE")
super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
super().__init__(
dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
)
@property
def loaded(self) -> bool:
return self.engine is not None
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
def load_model(self, data) -> Generator[str, None, None]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
error = ""
@@ -70,8 +77,9 @@ class WebChatModel(ChatModel):
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
flash_attn=(get("top.booster") == "flash_attn"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
)
@@ -79,7 +87,7 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang]
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
def unload_model(self, data) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_id("top.lang")]
if self.demo_mode:
@@ -107,6 +115,7 @@ class WebChatModel(ChatModel):
messages: Sequence[Dict[str, str]],
system: str,
tools: str,
image: Optional[NDArray],
max_new_tokens: int,
top_p: float,
temperature: float,
@@ -114,7 +123,7 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = ""
response = ""
for new_text in self.stream_chat(
messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
if tools:

View File

@@ -3,13 +3,13 @@ import os
from collections import defaultdict
from typing import Any, Dict, Optional
import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
MLLM_LIST,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
@@ -17,6 +17,11 @@ from ..extras.constants import (
DownloadSource,
)
from ..extras.misc import use_modelscope
from ..extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
@@ -101,6 +106,10 @@ def get_template(model_name: str) -> str:
return "default"
def get_visual(model_name: str) -> bool:
return get_prefix(model_name) in MLLM_LIST
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value=[], choices=[], interactive=False)

View File

@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
from ...data import Role
from ...extras.packages import is_gradio_available
from ..utils import check_json_schema
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
@@ -14,15 +17,21 @@ if TYPE_CHECKING:
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(show_copy_button=True)
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=2)
with gr.Row():
with gr.Column():
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=3)
with gr.Column() as image_box:
image = gr.Image(sources=["upload"], type="numpy")
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@@ -40,19 +49,21 @@ def create_chat_box(
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
return (
chat_box,
chatbot,
messages,
dict(
chat_box=chat_box,
role=role,
system=system,
tools=tools,
image_box=image_box,
image=image,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,

View File

@@ -1,10 +1,13 @@
import json
import os
from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
@@ -29,28 +32,38 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
except Exception:
return gr.Button(interactive=False)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
):
if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
return gr.Button(interactive=False)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
return gr.Button(interactive=True)
else:
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]:
with open(file_path, "r", encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):
return [json.loads(line) for line in f]
else:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file: str = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
if data_file.endswith(".json"):
data = json.load(f)
elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f]
else:
data = [line for line in f] # noqa: C416
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path):
data = _load_data_file(data_path)
else:
data = []
for file_name in os.listdir(data_path):
data.extend(_load_data_file(os.path.join(data_path, file_name)))
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)

View File

@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
@@ -18,7 +21,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({dataset_dir, dataset})

View File

@@ -1,12 +1,16 @@
from typing import TYPE_CHECKING, Dict, Generator, List
import gradio as gr
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
@@ -23,9 +27,11 @@ def save_model(
adapter_path: List[str],
finetuning_type: str,
template: str,
max_shard_size: int,
visual_inputs: bool,
export_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
export_device: str,
export_legacy_format: bool,
export_dir: str,
export_hub_model_id: str,
@@ -41,6 +47,8 @@ def save_model(
error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and adapter_path:
error = ALERTS["err_gptq_lora"][lang]
if error:
gr.Warning(error)
@@ -59,24 +67,28 @@ def save_model(
adapter_name_or_path=adapter_name_or_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=max_shard_size,
export_size=export_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
)
yield ALERTS["info_exporting"][lang]
export_model(args)
torch_gc()
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
export_legacy_format = gr.Checkbox()
with gr.Row():
@@ -95,9 +107,11 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.template"),
max_shard_size,
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,
export_device,
export_legacy_format,
export_dir,
export_hub_model_id,
@@ -106,9 +120,10 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
)
return dict(
max_shard_size=max_shard_size,
export_size=export_size,
export_quantization_bit=export_quantization_bit,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id,

View File

@@ -1,10 +1,13 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...extras.packages import is_gradio_available
from .chatbot import create_chat_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
@@ -25,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.update({infer_backend})
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
)
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
[chat_elems["image_box"]],
)
return elem_dict

View File

@@ -1,13 +1,16 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..utils import can_quantize
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
@@ -27,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
template = gr.Dropdown(choices=list(templates.keys()), value="default")
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
).then(get_template, [model_name], [template], queue=False).then(
get_visual, [model_name], [visual_inputs], queue=False
) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
@@ -56,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)

View File

@@ -1,13 +1,17 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
@@ -23,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})
@@ -134,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
create_new_adapter = gr.Checkbox()

View File

@@ -1,6 +1,4 @@
from typing import Any, Dict, Generator
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
@@ -10,6 +8,10 @@ from .runner import Runner
from .utils import get_time
if TYPE_CHECKING:
from gradio.components import Component
class Engine:
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
@@ -29,7 +31,7 @@ class Engine:
return output_dict
def resume(self) -> Generator[Dict[Component, Component], None, None]:
def resume(self):
user_config = load_config() if not self.demo_mode else {}
lang = user_config.get("lang", None) or "en"
@@ -41,6 +43,7 @@ class Engine:
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False}
if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}
@@ -55,7 +58,7 @@ class Engine:
else:
yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str) -> Dict[Component, Component]:
def change_lang(self, lang: str):
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()

View File

@@ -1,5 +1,4 @@
import gradio as gr
from ..extras.packages import is_gradio_available
from .common import save_config
from .components import (
create_chat_box,
@@ -13,6 +12,10 @@ from .css import CSS
from .engine import Engine
if is_gradio_available():
import gradio as gr
def create_ui(demo_mode: bool = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
@@ -55,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.add_elems("top", dict(lang=lang))
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
_, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", chat_elems)
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)

View File

@@ -129,6 +129,17 @@ LOCALES = {
"label": "加速方式",
},
},
"visual_inputs": {
"en": {
"label": "Visual inputs",
},
"ru": {
"label": "визуальные входы",
},
"zh": {
"label": "图像输入",
},
},
"training_stage": {
"en": {
"label": "Stage",
@@ -1073,6 +1084,17 @@ LOCALES = {
"placeholder": "工具列表(非必填)",
},
},
"image": {
"en": {
"label": "Image (optional)",
},
"ru": {
"label": "Изображение (по желанию)",
},
"zh": {
"label": "图像(非必填)",
},
},
"query": {
"en": {
"placeholder": "Input...",
@@ -1150,7 +1172,7 @@ LOCALES = {
"value": "清空历史",
},
},
"max_shard_size": {
"export_size": {
"en": {
"label": "Max shard size (GB)",
"info": "The maximum size for a model file.",
@@ -1192,6 +1214,20 @@ LOCALES = {
"info": "量化过程中使用的校准数据集。",
},
},
"export_device": {
"en": {
"label": "Export device",
"info": "Which device should be used to export model.",
},
"ru": {
"label": "Экспорт устройство",
"info": "Какое устройство следует использовать для экспорта модели.",
},
"zh": {
"label": "导出设备",
"info": "导出模型使用的设备类型。",
},
},
"export_legacy_format": {
"en": {
"label": "Export legacy format",
@@ -1287,7 +1323,12 @@ ALERTS = {
"err_no_export_dir": {
"en": "Please provide export dir.",
"ru": "Пожалуйста, укажите каталог для экспорта.",
"zh": "请填写导出目录",
"zh": "请填写导出目录",
},
"err_gptq_lora": {
"en": "Please merge adapters before quantizing the model.",
"ru": "Пожалуйста, объедините адаптеры перед квантованием модели.",
"zh": "量化模型前请先合并适配器。",
},
"err_failed": {
"en": "Failed.",

View File

@@ -60,4 +60,5 @@ class Manager:
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
self._id_to_elem["top.visual_inputs"],
}

View File

@@ -4,9 +4,7 @@ import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator
import gradio as gr
import transformers
from gradio.components import Component # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available
@@ -14,13 +12,20 @@ from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_gradio_available
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from .manager import Manager
@@ -62,7 +67,7 @@ class Runner:
if not model_path:
return ALERTS["err_no_path"][lang]
if len(dataset) == 0:
if not dataset:
return ALERTS["err_no_dataset"][lang]
if not from_preview and self.demo_mode:
@@ -117,8 +122,9 @@ class Runner:
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn=(get("top.booster") == "flashattn"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
@@ -217,8 +223,9 @@ class Runner:
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn=(get("top.booster") == "flashattn"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("eval.dataset_dir"),
dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),
@@ -239,7 +246,7 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
@@ -249,7 +256,7 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
@@ -263,19 +270,19 @@ class Runner:
self.thread.start()
yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
def preview_train(self, data):
yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
def preview_eval(self, data):
yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
def run_train(self, data):
yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
def run_eval(self, data):
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
def monitor(self):
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.aborted = False
self.running = True
@@ -332,7 +339,7 @@ class Runner:
yield return_dict
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
def save_args(self, data):
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
@@ -351,7 +358,7 @@ class Runner:
save_path = save_args(config_path, config_dict)
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
def load_args(self, lang: str, config_path: str):
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path)
if config_dict is None:

View File

@@ -3,21 +3,24 @@ import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional
import gradio as gr
from ..extras.packages import is_matplotlib_available
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import smooth
from .locales import ALERTS
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
if is_gradio_available():
import gradio as gr
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
if not callback.max_steps:
return gr.Slider(visible=False)