275 Commits

Author SHA1 Message Date
hiyouga
f0bff18324 Update publish.yml
Former-commit-id: 60b0633e29c9e701aa3813bd1fdc0282bd07f7c8
2024-06-19 20:46:33 +08:00
hiyouga
b631bdc5b7 release v0.8.2
Former-commit-id: 3050bbe51d46acd8473275d2713fc28932e4a3d3
2024-06-19 20:42:09 +08:00
hiyouga
c65f7e9bd5 fix jinja template
Former-commit-id: 0ebf2e2ee23918d28b0cbb20ba456732d6eedfbb
2024-06-19 20:03:50 +08:00
hiyouga
3e0fa4a8da fix templates
Former-commit-id: 6f357d59b73309c5955683008632e7f320e7dcb1
2024-06-19 17:44:05 +08:00
hiyouga
235ed85b0f fix bug
Former-commit-id: 412139eaa2fde98ba19e1257d21144382a59f0d6
2024-06-19 03:49:23 +08:00
hiyouga
1ca639a777 use prefix to replace force system
Former-commit-id: 731d9a964f1c3dbfb83825524d697831e691fb9d
2024-06-19 03:39:52 +08:00
hiyouga
e36a994fe6 fix tool formatter, allow parallel function #4362
Former-commit-id: b8f16c976db4ecec1cc8558851c8cbfb6a5b7e9c
2024-06-19 03:23:51 +08:00
hoshi-hiyouga
19ffcfea76 Merge pull request #4173 from mMrBun/main
Implemented the tool_formatter and tool_extractor for glm4 and Qwen2 tool_format

Former-commit-id: 36b02ceed40198ecd5d559ee4ebef9205442ded2
2024-06-19 03:18:55 +08:00
hiyouga
85f3a09c83 tiny fix
Former-commit-id: bb750fa3dde03ec024ae75596ecd4b884cb126c6
2024-06-18 23:32:18 +08:00
hoshi-hiyouga
60b9a9c1fa Merge pull request #4314 from EliMCosta/patch-2
Fix Dockerfile

Former-commit-id: a123a42d98f5c49446762c1d4cfc674d2e4f61b1
2024-06-18 23:30:59 +08:00
hoshi-hiyouga
984e38575c Merge pull request #4309 from EliMCosta/patch-1
Add Magpie and Webinstruct dataset samples

Former-commit-id: 70966de5d4df51a41fef1da5a919dd622aa9c86c
2024-06-18 23:30:19 +08:00
hiyouga
665df5d733 add deepseek coder v2 #4346
Former-commit-id: d83d3846d8e3bf5c40d4b90c24e2c5909ec61864
2024-06-18 22:53:54 +08:00
hiyouga
4bc0bea0e9 fix #4357
Former-commit-id: a6741bba8cebd16a6a3f97a2dc81057d0e27eb39
2024-06-18 22:42:45 +08:00
hoshi-hiyouga
5cfa342f01 Merge pull request #4334 from zzxzz12345/bugfix/add-pandas-versions
Update requirements.txt

Former-commit-id: 219eb5b346bce7e13c2c3511c1638f9dde595787
2024-06-18 22:30:35 +08:00
hoshi-hiyouga
c106cc24e4 Update requirements.txt
Former-commit-id: da8684f9f0b0103d4fa81279343a48ecd0fcc0cd
2024-06-18 22:27:24 +08:00
hiyouga
372da52d4a fix #4335
Former-commit-id: 2ab449adbb160f339a0586edeb846fa311ad8382
2024-06-18 22:08:56 +08:00
hiyouga
875270b851 lint
Former-commit-id: a19a7ac99af62b6715c96274f6350b124a784331
2024-06-17 22:35:56 +08:00
hiyouga
43fab306b6 update chat engine #4335
Former-commit-id: b163df7de48777e4319c9ccc736b0acdd5f473ed
2024-06-17 19:07:17 +08:00
hiyouga
77242f4169 update readme
Former-commit-id: 07c629f77c3978f339402e578cde1aede3f37699
2024-06-17 18:47:24 +08:00
hiyouga
60d9896a70 fix #4326
Former-commit-id: 3c2c45812a720d92f7f5b15b9f03370fe6bf069e
2024-06-17 18:17:48 +08:00
hiyouga
485a80d294 tiny fix
Former-commit-id: 2289436567a7860d25d9da0afb39e4a3e5e83839
2024-06-17 17:47:25 +08:00
胡翀
63bfe9967e Update requirements.txt
add pandas version requirements

Former-commit-id: ed1cf559aa2d02588aacf55a17b439473651f626
2024-06-17 16:45:57 +08:00
Eli Costa
a720b82e63 Fix Dockerfile
Adds the commands to correctly execute LLaMA-Factory servers

Former-commit-id: 22af40f0895a6f88709a495febeca8507d41d989
2024-06-16 19:16:23 -03:00
Eli Costa
d3b0048d8c Update README_zh.md
Fix details tag in datasets menus

Former-commit-id: d79c1bd4806e9ea13115fabebf9da2d19b0a52be
2024-06-16 11:34:31 -03:00
Eli Costa
9a0aca42a5 Update README_zh.md
Add Magpie and WebInstruct to README

Former-commit-id: 6cf5323959fe9500ba06ab28980fcc8f62e1373f
2024-06-16 11:22:06 -03:00
Eli Costa
5e802b0645 Update README.md
Add Magpie and Webinstruct to README

Former-commit-id: 2b32b9263f12605e48e11dce9b5fbb746d790745
2024-06-16 11:19:25 -03:00
hoshi-hiyouga
ca67b7a568 Update parser.py
Former-commit-id: d10c97193d08bd368aca1a72f0d1d8a96c76765d
2024-06-16 02:57:00 +08:00
hiyouga
76cd879c84 update pr template
Former-commit-id: 0b7c29674fda10c0ac87e0a0c75990feabb5a3de
2024-06-16 01:43:43 +08:00
hoshi-hiyouga
e0c049e590 Merge pull request #4307 from hiyouga/pissa
Support pissa

Former-commit-id: e7c0eefe96540c106162f5d252476b10b97ae696
2024-06-16 01:41:50 +08:00
hiyouga
727943f078 fix tol
Former-commit-id: bdb54bcb477126687db789bd89f2df84e424a2a3
2024-06-16 01:38:44 +08:00
hiyouga
8393b08666 Update tests.yml
Former-commit-id: 82e83615a706293abbf266d11c57caedafdd4c5b
2024-06-16 01:22:23 +08:00
hiyouga
9049f72d2f increase tol
Former-commit-id: c29071445e34aed23123fdf883a4d877744a1b0e
2024-06-16 01:21:06 +08:00
hiyouga
32f45c9e91 support pissa
Former-commit-id: ef8e45f2eaf466c54e9a671512a2974575677b08
2024-06-16 01:08:12 +08:00
hiyouga
05f3a3c944 tiny fix
Former-commit-id: f7f440986b0ae3b38ea9f2da80789629d4f79ea1
2024-06-16 01:06:41 +08:00
hiyouga
14f7bfc545 use fixture
Former-commit-id: 10761985691b9f934f7689c1f82aa6dd68febcca
2024-06-15 20:06:17 +08:00
hiyouga
7f90b0cd20 add tests
Former-commit-id: 484634ee9c982e82e919ff67d507e0210345182d
2024-06-15 19:51:20 +08:00
hiyouga
308abfec6c add minicpm #4227
Former-commit-id: e1bb18ce60be9a1b203989def30f1b9194286325
2024-06-15 17:58:52 +08:00
hiyouga
bb88536166 add license
Former-commit-id: 69cfc98d7c81756a5ab6bf962240e393e449fef0
2024-06-15 17:54:33 +08:00
hiyouga
d2df3f2d6e update readme
Former-commit-id: a43d302aa79cbfb9b0606e855b4c1af6865d8e68
2024-06-15 05:13:16 +08:00
hiyouga
2abfad9c1f fix #4271
Former-commit-id: 03707e78d29bfcf5d395a64bb38632bdb3ff47ce
2024-06-15 05:11:33 +08:00
hiyouga
2af932d969 disable DP
Former-commit-id: c18fd609d268389f3e65274992045a6c9f8e6c1f
2024-06-15 04:57:19 +08:00
hiyouga
c29fa61a9c fix #4292
Former-commit-id: 4cd4c179d24eab0fcaec2b29b9dd71970f877fe8
2024-06-15 04:47:13 +08:00
hiyouga
a30931fe0f fix #4295
Former-commit-id: 08f657868f9d605b837c5d8c2946a25cc05c8735
2024-06-15 04:34:55 +08:00
hiyouga
3ff9b87012 add test cases
Former-commit-id: 731176ff34cdf0cbf6b41c40c69f4ceb54c2daf6
2024-06-15 04:05:54 +08:00
hiyouga
f4f315fd11 Update README.md
Former-commit-id: f8d701cd3ce2e56f95b4f5439b8b48d5b62e0d2b
2024-06-13 16:02:21 +08:00
hiyouga
530165d9a5 update examples
Former-commit-id: d6bf6231290d79eb3a63e711f18fa711ef18a4f6
2024-06-13 03:26:10 +08:00
hiyouga
dbd1458adf add quant check in webui export tab
Former-commit-id: 6455ca07061ae9858cd7bc996b28be1fde697a3d
2024-06-13 03:19:18 +08:00
hiyouga
dedefecd2b Update llama3_full_sft_ds3.yaml
Former-commit-id: e715af62d521112d9c155cfa91fbb42fa0e77710
2024-06-13 03:16:20 +08:00
hiyouga
46f441dd37 update examples
Former-commit-id: 19681f93db399d695aa8e35f8ec2a9e720875baa
2024-06-13 03:15:06 +08:00
hiyouga
49b58fd6af fix #4221
Former-commit-id: 05a3be4853b941909e7d193c31e8d62c8c5f879b
2024-06-13 02:48:21 +08:00
hiyouga
103a507b39 fix #4209
DeepSpeed ZeRO3 has inflight param error when calling model.eval()


Former-commit-id: 4be013f18ea6a35b5a11db98db5f0670ffb41619
2024-06-13 02:25:50 +08:00
hiyouga
0a75224f62 clean code
Former-commit-id: f54cafd5c7f0383370d1a2f357834a61a97397ce
2024-06-13 01:58:16 +08:00
hoshi-hiyouga
04d7629abf Merge pull request #4246 from hzhaoy/adapt-vllm-v0.5.0
adapt vllm==0.5.0

Former-commit-id: 1068e25fc8b89f11cc79b164ee4aef9ce137ad4c
2024-06-13 01:54:02 +08:00
hiyouga
1b6786a21f add neo-sft dataset
Former-commit-id: 34863fa7cb641ceca92e3a2eec914126db537b62
2024-06-13 01:00:56 +08:00
hiyouga
5080f2314c fix lint
Former-commit-id: b170165679317af2b3f03633afac27661b3deb06
2024-06-13 00:48:44 +08:00
hiyouga
41beb7f0a3 fix docker compose usage
Former-commit-id: 59a5bd5d5c8d2a44e2dad26b74e77a45e109c8d6
2024-06-13 00:07:48 +08:00
hzhaoy
799873aa14 adapt vllm==0.5.0
Former-commit-id: 02afd9ff64f23e6707ac739ae1269f41bd70c340
2024-06-12 18:29:03 +08:00
hiyouga
fe2c7eaa93 update readme
Former-commit-id: a436aaa83f0cf12c8f404459e5486f9369d538ec
2024-06-12 17:39:12 +08:00
hiyouga
6392d45ea7 fix #4242
Former-commit-id: cf260e7af03f49aa5e3d6daf3b27738ff9b9bcb8
2024-06-12 16:50:11 +08:00
hoshi-hiyouga
c60ea675d7 Merge pull request #4234 from kimdwkimdw/patch-1
Support vllm==0.5.0

Former-commit-id: 0a9da057c9e7ef11cd709b20263c3d2e4c2d72ed
2024-06-12 16:39:09 +08:00
Arthur Kim
16c7c92396 Support vllm==0.5.0
Former-commit-id: e7a8ffd7af21bc3759f055033ba2209fa7a1be0e
2024-06-12 16:49:12 +09:00
hoshi-hiyouga
7598b37543 Merge pull request #4204 from dignfei/main
fixbug: llama3 should use <|end_of_text|> to mark the end of text during continual pre-training

Former-commit-id: e566342636faf0031a0ba5d5dd4fcff8401a2b76
2024-06-11 17:06:10 +08:00
hoshi-hiyouga
cc9717e2f2 Update pretrain.py
Former-commit-id: e2317b2a84149e39fddfd6366be3de23dfb71f82
2024-06-11 17:02:14 +08:00
hiyouga
08f2f99f4b fix deepspeed version
Former-commit-id: 938a69bb07d4de7d82928ff01c582032162c1480
2024-06-11 16:52:36 +08:00
d
77bf3d66c7 After extensive continual pre-training and comparison experiments, this bug was found: the tokenizer.eos_token that llama3 uses in pre-training is '<|end_of_text|>', and this token must also be appended after each data sample here, not '<|eot_id|>'; otherwise severe performance degradation easily follows
Former-commit-id: ef470561f742b16eaa0f99c4cadecd7c84ce6bd2
2024-06-11 16:23:40 +08:00
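A minimal sketch of the idea behind this fix, not the repository's exact code (the dataset field name "text" and the model id are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

def build_pretrain_texts(examples):
    # For Llama-3, tokenizer.eos_token is '<|end_of_text|>'; terminating each
    # packed document with '<|eot_id|>' instead can noticeably hurt continual
    # pre-training quality.
    return [doc + tokenizer.eos_token for doc in examples["text"]]
```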
hiyouga
f14f67f803 Update bug-report.yml
Former-commit-id: bb022cd867ebf2593e40fc6ba43b768603b129a3
2024-06-11 15:40:21 +08:00
hiyouga
820b6e7b32 fix #4198
Former-commit-id: 945d2c6cc73542adf9272ebd9aa332ea2c1c7361
2024-06-11 15:38:38 +08:00
hiyouga
27aece94cf tiny fix
Former-commit-id: c4b2e263d9cefbad0fbc5de72422e4ef8edbcb54
2024-06-11 12:48:53 +08:00
hoshi-hiyouga
3f2508be92 Merge pull request #4191 from iamthebot/al--add_manifest_for_reqs
Add MANIFEST.in so requirements.txt is present in sdist

Former-commit-id: fd6d1c3fce855d1ef7396cf33af9f12eadc5a878
2024-06-11 10:41:15 +08:00
Alfredo Luque
fce11bb386 add manifest so requirements.txt in sdist
Former-commit-id: b501a3c56c51786c3006a2aca15a145641a4556c
2024-06-11 00:07:06 +00:00
hiyouga
2723438531 tiny fix
Former-commit-id: b5e9711ef375cc323fc083e742cccfc974550416
2024-06-11 01:04:16 +08:00
hiyouga
f330b73682 set dev version
Former-commit-id: 16c47cc15226119e33e46ba0f2f6ccb37072257f
2024-06-11 00:50:53 +08:00
hiyouga
0f1e592326 release v0.8.1
Former-commit-id: 875a34f492701d1c644facbe9ede411af2931513
2024-06-11 00:44:26 +08:00
hiyouga
4d7dd0330d fix #4160
The split heads should be concatenated in dim=2


Former-commit-id: 4b3f247f270d44df9fe226cfe0dabfb7fcd2deda
2024-06-11 00:37:17 +08:00
hiyouga
ea2ca2777f fix #4145
Fix the docker image


Former-commit-id: a9838281156fe870bfcde5d1f7afc15264fd4aad
2024-06-11 00:19:17 +08:00
hiyouga
4b2b92fd9a update evaluator
Former-commit-id: bb8661e62481ff7027b8969f3d8a6a17290c9da3
2024-06-10 23:56:00 +08:00
hiyouga
784088db3f fix #2666
Former-commit-id: f121d5c4f94af9f165132c4309cb9bdc8217d985
2024-06-10 21:24:15 +08:00
hoshi-hiyouga
0ecf0d51e3 Merge pull request #4167 from yzoaim/branch
fix README

Former-commit-id: 1a877b0fbf54478dbf905fb3e84bd079a55bb725
2024-06-10 16:24:33 +08:00
mMrBun
bc04ca464a Optimize the handling of QWEN2 in scenarios involving multiple tool calls.
Former-commit-id: 48f870edc96ada40360f7e6e67cbf58805295b33
2024-06-10 02:00:14 +08:00
mMrBun
44829df762 Removed unnecessary comments.
Former-commit-id: 2b81252aa693871098931cd7873ef83ef4922ba5
2024-06-09 18:25:22 +08:00
mMrBun
94ddfa66c0 Merge branch 'hiyouga:main' into main
Former-commit-id: c25734d874a36222e0a540a2c994bbda73008b27
2024-06-09 18:17:24 +08:00
mMrBun
8db8ed5a41 Implemented the tool_formatter and tool_extractor for glm4 tool_format
Former-commit-id: db7fa4490ea7f6966418d2879c895cbc1763b16d
2024-06-09 18:16:15 +08:00
-.-
041ecd0de1 fix README
Former-commit-id: fa30028c0b83c38610b596209493a748b8ca0928
2024-06-08 23:51:56 +08:00
hiyouga
d812249db7 add pr ci
Former-commit-id: 9b05bb8540b946d0c74bf804bcafc4a785d22c47
2024-06-08 21:25:35 +08:00
hiyouga
88528f1a87 Update tests.yml
Former-commit-id: e90f0cc30d6bb819246ccc08935c39e714c179a1
2024-06-08 21:15:36 +08:00
hiyouga
82533114a7 update git workflows
Former-commit-id: 5a3f26bc53433caa98b2a66294becaf156280a4c
2024-06-08 21:11:32 +08:00
hiyouga
6d9fbb3fa9 fix llamafactory-cli env
Former-commit-id: b0515e5f42831b67d1f4d049999ecb68756e66db
2024-06-08 07:15:45 +08:00
hiyouga
9953ae3d03 set dev version
Former-commit-id: 08b7fe1c452cc99264ff0312e310b579590c6a45
2024-06-08 06:46:09 +08:00
hiyouga
c0c387e4db release v0.8.0
Former-commit-id: 004db680b9e3996ec511ee818df6c0c02bf13603
2024-06-08 05:20:54 +08:00
hiyouga
ae60ea15da add ultrafeedback and fineweb #4085 #4132
Former-commit-id: 968e4992e2f2a3ccba73e8668f1654ddc6eb0034
2024-06-08 02:42:34 +08:00
hiyouga
72cd1123a8 fix ci
Former-commit-id: 3f4d293fd861d765edb2040f80d16f99a5e1e3c6
2024-06-08 02:00:44 +08:00
hiyouga
1364190a66 fix ci
Former-commit-id: 95aceebd61d195be5c980a919c12c59b56722898
2024-06-08 01:57:36 +08:00
hiyouga
6d17c59090 add ci
Former-commit-id: 3ea3acdadaa54abe33d93538580196cfdd91ee56
2024-06-08 01:48:30 +08:00
hiyouga
e0f2c0b5dc init unittest
Former-commit-id: 1c6f21cb8878ced043fe0b27c72cad2ef6ee990e
2024-06-08 01:35:58 +08:00
hiyouga
073e34855d Delete .readthedocs.yaml
Former-commit-id: dd3ee514216a9a329519c58d79208040adcad126
2024-06-08 00:58:10 +08:00
hiyouga
ff9ba70bb8 reorganize adapter code
Former-commit-id: b26c2df9d97f4efffccbf7d28de13619b43f10dd
2024-06-08 00:47:23 +08:00
hoshi-hiyouga
adbebb0e3f fix #4139
Former-commit-id: c025a4d74f293c14c2705e68af20a82a84608520
2024-06-08 00:45:02 +08:00
hiyouga
3f6b3eed98 add resume args in webui
Former-commit-id: 1d86ad768b1f36e54b4c2a9f18f6ea5a7df04c90
2024-06-08 00:22:16 +08:00
hiyouga
f45e81e186 fix #4137
Former-commit-id: cdc0d6f5a2e5040e145c82c4801f37bd76529047
2024-06-07 19:16:06 +08:00
hiyouga
ba648fd003 tiny fix
Former-commit-id: 0621bcad1dfbe8ce2464f741d4256c5df2a8d1b6
2024-06-07 05:19:21 +08:00
hiyouga
b0e5a76f4c fix ppo trainer save zero3 model
accelerator.get_state_dict(ds_model) should be called at all ranks


Former-commit-id: 3a0f60f0aa072531e4ae5819ec00c8fa42aa0913
2024-06-07 05:14:19 +08:00
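A hedged sketch of the saving pattern this message describes (illustrative helper, not the trainer's actual code): under DeepSpeed ZeRO-3 the weights are sharded across ranks, so gathering the state dict is a collective call that every rank must enter, while only the main process writes the file.

```python
import torch

def save_zero3_model(accelerator, ds_model, path="pytorch_model.bin"):
    # Collective under ZeRO-3: every rank must call get_state_dict(), otherwise
    # ranks that skip it deadlock while the others wait to gather the shards.
    state_dict = accelerator.get_state_dict(ds_model)
    if accelerator.is_main_process:
        torch.save(state_dict, path)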
hiyouga
8692796c9b fix ppo in trl 0.8.6
Former-commit-id: 5e0d66a0d80b4bd4a8506e2317209d8fb9d25ff6
2024-06-07 04:48:29 +08:00
hiyouga
d0edcde4ea fix #4120
Former-commit-id: 2a44da678a5e360a9c0f9056397ac9e801329321
2024-06-07 04:18:05 +08:00
hiyouga
8c4c2e580c update data processors
Former-commit-id: 04b138cbcb8b9a72e4bbda6c65843bb459e525e7
2024-06-07 04:15:40 +08:00
hoshi-hiyouga
07f33e7641 Merge pull request #4009 from AlongWY/main
supervised packing with greedy knapsack algorithm

Former-commit-id: 5ded166b39a75a98ded5733678f5a1eab7d4cc71
2024-06-07 03:48:46 +08:00
hoshi-hiyouga
1998c641af Update supervised.py
Former-commit-id: 04b6c2a754e602e0b698cfe6c255c2f2486d8865
2024-06-07 03:42:08 +08:00
hoshi-hiyouga
be1e5f9d62 Update supervised.py
Former-commit-id: 49993c4f4e1f871a22ff0196afe60026b668a4dc
2024-06-07 03:38:23 +08:00
hoshi-hiyouga
fdeec6db52 Update supervised.py
Former-commit-id: 67625b5278a839c12a3e4245f9e90af67d8b11b4
2024-06-07 03:38:04 +08:00
hiyouga
a4d335b42f add qwen2 models
Former-commit-id: 49cb694d02c876e3740a003a8b332349f4310ad3
2024-06-07 00:22:57 +08:00
hiyouga
fcb134e144 rename files
Former-commit-id: e1a8431770fc36c0c9ee7fed4abbc3d7fdcc5efd
2024-06-07 00:09:06 +08:00
hiyouga
a47e24222a add DISABLE_TORCHRUN option
Former-commit-id: bcc574b479c2101438723aadead42743d4378776
2024-06-06 23:44:58 +08:00
hoshi-hiyouga
b96b995620 Merge pull request #4082 from MengqingCao/bugfix
Fix #4077

Former-commit-id: 288028c3fb6bb1b58d1b7f4e8b90108c9bbf27d1
2024-06-06 23:38:40 +08:00
hoshi-hiyouga
c231706aa5 Update cli.py
Former-commit-id: 32190507534adf5f505858b3af2b592ca6568ac7
2024-06-06 23:38:09 +08:00
hiyouga
35b5117a59 fix ppo+zero3 #3108
Former-commit-id: 33a93cc29e3e57bf001515000c0a70c112573dea
2024-06-06 23:30:07 +08:00
hiyouga
80f716bc10 fix torch gc
Former-commit-id: e173799d057598e5692a407601c30d8ce1513461
2024-06-06 20:30:25 +08:00
hiyouga
ca95e98ca0 fix ppo dataset bug #4012
Former-commit-id: 7fc51b2e93698ae5e012566af8481f4d861c873d
2024-06-06 19:03:20 +08:00
hiyouga
d5559461c1 update trainers
Former-commit-id: b7f6c4a171293cf4f3e88f15a811f847342f84ee
2024-06-06 18:45:49 +08:00
hiyouga
f4acd81e2f fix base64 image read #4061
Former-commit-id: 66ccb2a27a04296b4600f2c85f428071bf14eeb0
2024-06-06 17:29:19 +08:00
hiyouga
31feb6e26c update readme
Former-commit-id: cc331fa2d28afe081937c50ea83d63add21d4e3a
2024-06-06 16:59:18 +08:00
hiyouga
7d5c0a069c update readme
Former-commit-id: fb1f709af5199976e63d7188e088e33c75d19bfe
2024-06-06 16:25:42 +08:00
hiyouga
937f49ec3d lora modules: all by default
Former-commit-id: 52c4ae87c7f4312704c31ef26b079b2c5b95ea5f
2024-06-06 03:53:28 +08:00
hiyouga
abc2a73a33 add codestral 22B
Former-commit-id: b011c7f527a57cb1d21c4e2c9631c2fb62bb835e
2024-06-06 03:42:50 +08:00
hiyouga
5e1bf7572c lint
Former-commit-id: 9030501eaef97ea249347198272adf0d709503ec
2024-06-06 03:33:44 +08:00
hoshi-hiyouga
8fdb32d0a3 Merge pull request #4066 from injet-zhou/main
add throughput entry to training log

Former-commit-id: d2816f343f405f3fab09f2a8eade774b886e8f92
2024-06-06 03:32:04 +08:00
hoshi-hiyouga
c709d5f7db Merge pull request #4080 from MengqingCao/npu
Add npu option for model exporting

Former-commit-id: 07fc67193ef6bcb8e8a392aff0c57a2eb36832bf
2024-06-06 03:15:44 +08:00
hoshi-hiyouga
f5b2749ec2 Update export.py
Former-commit-id: 694833c1104d13929d4f181f014a121f25955dc5
2024-06-06 03:14:46 +08:00
hoshi-hiyouga
ee5853c565 Update model_args.py
Former-commit-id: 09c0afd94a8a5f5b45a61b32c983d50e1b9e2941
2024-06-06 03:14:23 +08:00
hoshi-hiyouga
6ec6df8a5f Merge pull request #4053 from hzhaoy/feature/add_select_config_file
Support selecting saved configuration files

Former-commit-id: 568ef3cf2a793f268cbe01c39dec418a13e61ecd
2024-06-06 03:06:03 +08:00
hiyouga
fc95800840 add vllm_dtype arg #3387 #3717
Former-commit-id: a0dd3a6351bb78541d40fec1d2fc457d803c86a4
2024-06-06 02:53:27 +08:00
hiyouga
765715af21 support train from scratch #4033 #4075
Former-commit-id: 1290b9d01077e62f8de7a23637daa2586cc82bfa
2024-06-06 02:43:19 +08:00
hiyouga
639a7f6796 support image input in api #3971 #4061
Former-commit-id: c70aaf763ef22fb83ce3635e8ffd5ec4c89c1cb0
2024-06-06 02:29:55 +08:00
hiyouga
35379c7c0e update train hparams
Former-commit-id: 1ca9fce55b55bf209f4b76152b586731932a3f39
2024-06-06 01:49:20 +08:00
hiyouga
d992f5353f fix setup
Former-commit-id: b2b80d434fcc0c3838d229098e1c21d26632204c
2024-06-06 01:39:02 +08:00
hiyouga
875eef45f3 add llamafactory-cli env
Former-commit-id: 1df077184845ff5f394b9324d46f8c382869e590
2024-06-06 01:28:14 +08:00
hiyouga
556a4aa972 fix #4090
Former-commit-id: d9f15f30a8f4bc64778a5c96baeb6801700d7a2c
2024-06-06 00:50:32 +08:00
MengqingCao
8dc1969111 modify export_device option
Former-commit-id: b2fc4a5499e21a5b9622c2285402efef6e27a74d
2024-06-05 09:37:36 +00:00
hiyouga
b74c229498 fix #4079
Former-commit-id: fda732d7f4616373844c97beff416880260f49db
2024-06-05 16:56:54 +08:00
hiyouga
3dbca466fd update readme
Former-commit-id: 02d34db29a7a35c25711d49e98fd3167a2f4dfe7
2024-06-05 16:32:32 +08:00
MengqingCao
ce6f7fdb82 fix #4077
Former-commit-id: fedbe92f3b56294acc6c49f9a51e369cf2de3ead
2024-06-05 08:03:30 +00:00
hiyouga
7528bc1bc0 support glm-4
Former-commit-id: a10f4718fbf3f3c89dc7eb31cb8e1a46ca6adda5
2024-06-05 15:16:38 +08:00
MengqingCao
9dd5f7d642 add npu for model export
Former-commit-id: ce020b6eb3f35c1db37ee4835e694eddcd0f59b0
2024-06-05 07:06:40 +00:00
faddddeout
99ecb0daaf add throughput entry to log
Former-commit-id: 691f999f64c7bac78761e4354f89816d2f0d46fc
2024-06-04 11:04:29 +00:00
hzhaoy
39d8d7995a add: support selecting saved configuration files and loading training parameters
Former-commit-id: 5c9b17c1dc9093da0ea813642bce9b5c9ae96274
2024-06-04 10:33:43 +08:00
hiyouga
2ac2cde03e tiny fix
Former-commit-id: f9d50501aac1f60a3b445ca3fee9aa60995461ee
2024-06-04 00:31:10 +08:00
hiyouga
aa6c3766de fix #3873
Former-commit-id: 1ac325b4d682bb493573c18bb0b67ceae8d0d372
2024-06-04 00:21:50 +08:00
hiyouga
f4f5d7e3ce fix #3992
Former-commit-id: a48321fbf5196b88a11106cf74a74fbcea2ea50b
2024-06-04 00:17:36 +08:00
hiyouga
efbf6018d3 fix abort in webui DDP mode
Former-commit-id: b90ac72d753b13a3eed9cb8b898fac2f2fe5153f
2024-06-04 00:10:24 +08:00
hoshi-hiyouga
1090bb8bf3 Merge pull request #3987 from injet-zhou/main
Fix cannot interrupt training when using multiple GPUs in webui

Former-commit-id: 455bb158b0e600723d2afaa2070b71178f2f5188
2024-06-04 00:04:07 +08:00
hiyouga
26bc79f971 fix #4043
Former-commit-id: 67af68f4fc5232760c57b3a0ae780628da09db6a
2024-06-03 23:30:37 +08:00
hiyouga
4c1f015eca remove gc warnings in DPO&KTO
Former-commit-id: b649bdcbafb464a638387429b770fe258b41f8af
2024-06-03 22:53:54 +08:00
hoshi-hiyouga
0655a183d3 Merge pull request #4045 from enji-zhou/feature/add_kto
fix KTO Trainer Sampler

Former-commit-id: 8e235beb9cf4939c06ccb753b047326a9839e77f
2024-06-03 22:09:25 +08:00
hoshi-hiyouga
7754024e9b Update trainer.py
Former-commit-id: 8565d4b43db905374c328ae57c71fc226980d14f
2024-06-03 22:08:38 +08:00
enji.zhou
b4913569a8 fix KTO Trainer Sampler
Former-commit-id: 39eb1bfa272011554322e9bb2534f83b68282a70
2024-06-03 21:32:38 +08:00
hoshi-hiyouga
eae9f09ca8 Merge pull request #4006 from Uminosachi/scheduler-kwargs
Set scheduler_specific_kwargs to get_scheduler

Former-commit-id: c6ed1955fd8990ddb960750913c9d8b13fe0ace3
2024-06-03 19:27:53 +08:00
hiyouga
8264e5ceaa update placeholder in issue template
Former-commit-id: 5503a90d7e38273b67129e0b9eb62bd1fd23154f
2024-06-03 19:24:10 +08:00
hoshi-hiyouga
b76f319e45 Merge pull request #4011 from statelesshz/issue-template
Update bug-report.yml

Former-commit-id: 1fbc46f45ae4e673f0b20b5eacab3d81d1053807
2024-06-03 19:20:43 +08:00
hiyouga
82d744716a fix #4005 #4013
Former-commit-id: 8608fa268cde5cddf8d0c6c2eb2cb5fa246c1831
2024-06-03 19:12:29 +08:00
hoshi-hiyouga
1a3764ab8f Merge pull request #4007 from xu-song/patch-3
Update model_args.py

Former-commit-id: d88b3a0f2707bcc964f642d348295b99f7c796f8
2024-06-03 18:54:37 +08:00
hiyouga
d2ede9d393 fix #4022
Former-commit-id: 9541f2f1f1b7d7877eb734f051048e52003a3430
2024-06-03 18:38:36 +08:00
hiyouga
5690f513fc bump versions
transformers 4.37.2->4.41.2
datasets 2.14.3->2.16.0
accelerate 0.27.2->0.30.1
peft 0.10.0->0.11.1
trl 0.8.1->0.8.6


Former-commit-id: 5f1e041f7295bf42a41dd4d9e7f0c42fcc37fed2
2024-06-03 18:29:38 +08:00
hiyouga
123a845209 fix data loader hint
Former-commit-id: 25b56126a11591b0155e2f72b673dd8f45a6c8c9
2024-06-03 18:28:27 +08:00
ylfeng
b1b7d735b3 remove empty line
Former-commit-id: 3164710971a6d6545629f5bf133f98de5ff0991a
2024-05-31 21:43:08 +08:00
ylfeng
230c69f7ce fix eos
Former-commit-id: 6e236c952958cbfe50b5dcb7b8eff6aea8477922
2024-05-31 21:40:41 +08:00
ylfeng
bfc43558ef supervised packing with greedy knapsack algorithm
Former-commit-id: 24d12396c9aabd49da0b08719068f24679111cc6
2024-05-31 15:33:54 +08:00
Xu Song
f2ae2cc04d Update model_args.py
Former-commit-id: f1e018587e5722e41962abd60f74043a3e55f692
2024-05-31 14:35:48 +08:00
statelesshz
6e9c03f958 Update bug-report.yml
Former-commit-id: a8561502360c1e247eeacb46b77ffbcf3387c482
2024-05-31 13:18:18 +08:00
Uminosachi
2696f614a7 Set scheduler_specific_kwargs to get_scheduler
Former-commit-id: f04e70dfab44480ef4c015c06470443237f69ba9
2024-05-31 13:45:39 +09:00
hiyouga
070b944895 update readme
Former-commit-id: 3b92d8c2ddb288b849f38e573ca168cab23315d2
2024-05-30 16:40:17 +08:00
faddddeout
f5f091d390 fix cannot interrupt training when using multiple GPUs in webui
Former-commit-id: a7fb02d52bc202c958490aa7081252be5d9eff50
2024-05-30 08:39:21 +00:00
hiyouga
14ab14a0e6 fix #3837
Former-commit-id: 72965aa3f13a9c085c29781b6790d80d00a545d8
2024-05-30 00:52:26 +08:00
hoshi-hiyouga
4f7c850115 Merge pull request #3829 from seanzhang-zhichen/add_dataset_sample_num
Add dataset sample num

Former-commit-id: ab38cf74ce48ea4f1800e077ca287f2eb9336135
2024-05-30 00:25:45 +08:00
hoshi-hiyouga
391eca66cf Update loader.py
Former-commit-id: 0aa59322906d91c5e385c9c02ebb5dd64ba060f3
2024-05-30 00:20:20 +08:00
hoshi-hiyouga
a67199246d Update loader.py
Former-commit-id: aa7f335e3ad5a78e4ed5f99c120be28e9733ea2e
2024-05-30 00:17:21 +08:00
hoshi-hiyouga
5f67fdaac9 Update loader.py
Former-commit-id: 19d8fd62c18ee3ba0e431fc241f7d315cb716fef
2024-05-30 00:12:12 +08:00
hoshi-hiyouga
05e6fe4287 Update parser.py
Former-commit-id: 310cc11e8c83f16fc5bccc349c38fea347ea9a97
2024-05-30 00:05:20 +08:00
hoshi-hiyouga
91cc571e6e Update README_zh.md
Former-commit-id: 3007d260ed45169583a74497a53b661337dd5f71
2024-05-30 00:04:47 +08:00
hoshi-hiyouga
890926e60c Update README.md
Former-commit-id: 65fb69e388c0a04c15ecd11441e567966f51fae5
2024-05-30 00:04:26 +08:00
hiyouga
87aa332583 better llamaboard
* easily resume from checkpoint
* support full and freeze checkpoints
* faster ui


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
2024-05-29 23:55:38 +08:00
hiyouga
f90c4ca672 fix cohere system
Former-commit-id: 5d629b29e705c8ff8dd4521719d9c0e67a3fe0a2
2024-05-29 20:58:23 +08:00
hiyouga
a922e85a5c fix #3965
Former-commit-id: 37d15ac55d0be0ff47d6a88f07e2d823117a4a36
2024-05-29 20:55:51 +08:00
hiyouga
9a65820592 update readme
Former-commit-id: 440e9de66986ef7736361ce8ec3e23ce68655a56
2024-05-29 18:39:11 +08:00
hoshi-hiyouga
f4e16ae373 Merge pull request #3930 from MengqingCao/npu
Add Ascend npu doc and dependency

Former-commit-id: 7210090e4fc6531b9f6122f104875811a8798185
2024-05-29 18:33:38 +08:00
MengqingCao
e2cfd34da0 update torch-npu version
Former-commit-id: a70d7fcf2967eb30280a1fb845b39db7878f535c
2024-05-29 10:05:11 +00:00
MengqingCao
668dea9706 update cann kernels url
Former-commit-id: 23c65e9d7e8817b5815264e44cbf4a7bcb88d3d7
2024-05-29 09:53:31 +00:00
hoshi-hiyouga
084be442f2 Merge pull request #3958 from hzhaoy/add_telechat_12b_support
add TeleChat-12B/TeleChat-12B-v2 models

Former-commit-id: c228546a09764423ae66966079802022185f7e86
2024-05-29 17:20:53 +08:00
hzhaoy
29cb4a1327 add TeleChat-12B/TeleChat-12B-v2 models
Former-commit-id: e0675385c88af03aaef8d51586c8a282829c4051
2024-05-29 15:00:37 +08:00
hiyouga
81a61134b8 fix hf chat engine
Former-commit-id: 76ce52911690ab0dd8ffa5587127afb4ec942abe
2024-05-29 01:20:07 +08:00
hiyouga
cb1a49aa02 add ds config to webui
Former-commit-id: 66d72b263d36dc81de9f6152077663b613035977
2024-05-29 01:13:17 +08:00
hiyouga
351b4efc6c 10x generate in ppo w/ zero3
https://github.com/huggingface/trl/pull/1483

Former-commit-id: 5dc43ba8b373d8803bc22d88b3d0d95ef8b9c7f8
2024-05-29 00:23:23 +08:00
hiyouga
9b551309de update dpo, kto trainer
Former-commit-id: 4a6cc3c7046f8b27d05ea53ef216bab6fa7ebfaf
2024-05-29 00:14:29 +08:00
hiyouga
9fed4a2ef4 clean kto trainer
Former-commit-id: 76402bd78cbd3a99a544f0ac019468b569b0e1d1
2024-05-28 21:43:26 +08:00
hiyouga
bceac4f554 bump vllm version to 0.4.1
Former-commit-id: a00fd39a4c2f270620711f2bfbad8d460fb4aa89
2024-05-28 21:27:27 +08:00
hiyouga
ae3a88d3a7 update readme
Former-commit-id: bc861f76706df3f643028f1dfc8ec2044b067a08
2024-05-28 19:35:52 +08:00
hiyouga
9138a7a5ba support DDP in webui
Former-commit-id: d059262ff8dc857f597d2657546ec625726a664a
2024-05-28 19:24:22 +08:00
hiyouga
9912b43fcc update readme
Former-commit-id: e2c7de1b5147801b301cfc5da0e2866273da18f5
2024-05-28 16:41:34 +08:00
hiyouga
5ac37555a4 update readme
Former-commit-id: 30ef8ee1e86136f38f105b67f70c417d20552f41
2024-05-28 16:19:56 +08:00
hiyouga
34bdc730a6 fix #3931
Former-commit-id: 47e0072416b545d9718af4fa266a83f747b9a4f7
2024-05-28 13:44:22 +08:00
MengqingCao
e45a9d70fc add Ascend npu doc and dependency
Former-commit-id: 803d9f142a294f8c1e0b4e2046c214b0857ccfd6
2024-05-28 01:33:54 +00:00
hoshi-hiyouga
232b36059c Merge pull request #3925 from Yimi81/feat-fix-yi-template
fix yi template

Former-commit-id: 6caee1eb868b9f7b00578c6608883e89aa232d17
2024-05-27 22:59:32 +08:00
Yimi81
d9fbd675d5 fix yi template
Former-commit-id: b3669c8989c3adda305416245e32e9e5a3b7caac
2024-05-27 13:11:25 +00:00
hiyouga
0206e7b9de tiny fix
Former-commit-id: 4c47b3dcef9e400a1c35fce1ad53619a0a86fe81
2024-05-27 20:54:26 +08:00
hoshi-hiyouga
a886544d3d Merge pull request #3921 from gusye1234/main
Add openchat-3.6-8B support

Former-commit-id: 92e6bba3cab22b7835a68f787caf7992a398978e
2024-05-27 20:52:37 +08:00
hoshi-hiyouga
8c9b929bb0 Update template.py
Former-commit-id: f4dabce0a71c9978e051e70886941b64b928ffe2
2024-05-27 20:51:56 +08:00
hoshi-hiyouga
1bb1ae834e Update template.py
Former-commit-id: af869e4c48eb426c4078415533f6dab89123a9d8
2024-05-27 20:51:26 +08:00
Jianbai Ye
0d9e364a90 add openchat-3.6-8B support
Former-commit-id: b66f39d50d896d7597a1506e67ec210b31c9b700
2024-05-27 20:42:08 +08:00
hiyouga
3b28c003dd fix full/freeze tuning for mllm
Former-commit-id: df5860ddb593d5b82163a585d12160b41dbce0f3
2024-05-27 20:37:57 +08:00
hoshi-hiyouga
48ff9fb150 Merge pull request #3835 from BUAADreamer/main
fix some features in llava-style training

Former-commit-id: fc8583bd17dfb088a52e4d8fa91356b918373b50
2024-05-27 20:23:45 +08:00
hiyouga
c43bc74fe6 support Aya23
Former-commit-id: 071935b90006e2c79e39bb9ee0c5d48c6c910501
2024-05-27 20:23:24 +08:00
BUAADreamer
eaf9cc2195 Merge branch 'hiyouga:main' into main
Former-commit-id: cc1b82bf49b060987392c455fdbfe125ad667ec5
2024-05-27 20:10:58 +08:00
hiyouga
4bd276f58f add llava 1k datasets
Former-commit-id: 345d3355752f4a4dc454696a39f1610fffbbf382
2024-05-27 19:57:33 +08:00
hiyouga
f8cf0d5e5d update dpo examples
Former-commit-id: 69e32a7cb6336ca9a953c379ec794818b3f169bd
2024-05-27 19:56:04 +08:00
BUAADreamer
79bc60db33 Merge branch 'hiyouga:main' into main
Former-commit-id: d89e1f8bf8bad1dd125b4de8fe6c0b2b16411cb5
2024-05-27 19:00:48 +08:00
BUAADreamer
dc7c54067e add only tune lm and mm_proj
Former-commit-id: ba12ca430ec527fbfe4cd1eace0adb5c7712146a
2024-05-27 19:00:15 +08:00
BUAADreamer
932f0d5c20 add regex of only tune lm and mm_proj
Former-commit-id: 38d540b3e69bceabafafab524fcfc78aeb05612d
2024-05-27 18:59:00 +08:00
hiyouga
9670f5e41a add phi-3 7b/14b, mistral v0.3 models
Former-commit-id: 86dab182f9710b063f518922ccb49b01aa71c576
2024-05-27 18:20:16 +08:00
hiyouga
97a23e1cbe update readme
Former-commit-id: b8d0170fe0d094acce85dcb5f91775e4685ee055
2024-05-27 18:14:02 +08:00
BUAADreamer
11fcd055ec Merge branch 'hiyouga:main' into main
Former-commit-id: 113be744b3d044fbea3a8654158aa83ddb4599eb
2024-05-27 11:54:01 +08:00
hiyouga
b0d9966663 support SimPO #3900
Former-commit-id: 6b954ce60155cf8334150b795cfc4bb63ca74c8b
2024-05-26 23:46:33 +08:00
BUAADreamer
5c51ab7e1f Merge branch 'hiyouga:main' into main
Former-commit-id: fd5420c43e1414bcd3fadb6239f4e5d42e6ac10e
2024-05-25 14:18:49 +08:00
hiyouga
26f293d587 fix #3853
Former-commit-id: 465a5500bae1f30744d4b9b3db40aaf9171da2cb
2024-05-24 23:29:45 +08:00
seanzhang-zhichen
a3b52fd380 Merge branch 'main' into add_dataset_sample_num
Former-commit-id: 26300127c45f24e63b91f1b0cc73e46c3a936a91
2024-05-24 15:57:47 +08:00
BUAADreamer
27d8706d6d Merge branch 'hiyouga:main' into main
Former-commit-id: a4ce5ee381fd59f6b254ab634af51b6bb54edd97
2024-05-24 09:50:00 +08:00
hiyouga
bf59383783 refactor data preprocessing, fix mllm rlhf
Former-commit-id: 53ff2dd24f9121ea30c95063bb72e49a9b31e980
2024-05-24 04:08:25 +08:00
hoshi-hiyouga
1078611259 Merge pull request #3876 from dongdongqiang2018/main
added adapted to 910B image

Former-commit-id: 0708cc8a24589b9f22ad3df6685e57d1da0336f2
2024-05-24 01:54:30 +08:00
hiyouga
e6fc0ac8fe fix paligemma sft
requires transformers>=4.41.1


Former-commit-id: 80b3030569cd606ac0de43e9a682478f5bd7b727
2024-05-24 00:23:40 +08:00
hiyouga
554ca3d8dc fix oom issues in export
Former-commit-id: b7ccc882a192aa1e25b1e5816f875ea304282412
2024-05-23 23:32:45 +08:00
donggang
86dfdf956d adapted to 910B image
Former-commit-id: e095254808aace63a1be878620f683902f51cfb3
2024-05-23 09:48:22 +00:00
BUAADreamer
c0e4475485 Merge branch 'hiyouga:main' into main
Former-commit-id: 4076f52c8ba7da4624a1fb3fa52a7170d1c3171e
2024-05-21 22:18:20 +08:00
hiyouga
2b65f8bd5c fix paligemma sft
Former-commit-id: 60682d04414be37e611d6470618a8d599703942b
2024-05-21 20:03:09 +08:00
hiyouga
09e78272c2 Update README_zh.md
Former-commit-id: 34c4ba6bf9bb89170446fb396aa06ae44d251de0
2024-05-21 18:30:59 +08:00
hiyouga
cccce564bd update wechat
Former-commit-id: 6613349562194b48c5fc57aa68e620b8fa83fc0a
2024-05-21 18:22:32 +08:00
hiyouga
4adec327de fix #3847
Former-commit-id: d206b306ca4eadc8b3d4feaf490ad12f9452e562
2024-05-21 17:53:06 +08:00
BUAADreamer
1f093334d1 support pretraining of llava
Former-commit-id: 6a4c8cf0a6a1674c693b9337f018ff8df7477f8f
2024-05-21 08:57:14 +08:00
hiyouga
e0e8507108 support paligemma
Former-commit-id: 11c27f9bf204d3d6a9ca5bd4f0a19a420160453f
2024-05-21 00:01:22 +08:00
hiyouga
f5962f8128 fix paligemma data preprocess
Former-commit-id: 71b85437301739d9d96d3881d4a34b37c0f69db8
2024-05-20 23:51:32 +08:00
hiyouga
b31d808655 fix paligemma inference
Former-commit-id: 46357b7a677e8ba2e0a7c9d4ec1974abd061569c
2024-05-20 23:36:43 +08:00
hiyouga
247cda4b68 fix #3818
Former-commit-id: 3f366e05a34be224f53c5bf8334e57ae5d316004
2024-05-20 21:43:19 +08:00
hiyouga
e30975e9a2 add kto to webui
Former-commit-id: 6c866f4dbd45e868860be8351d1a65c4e1a4e02b
2024-05-20 21:20:25 +08:00
zhangzc
de9f1583c2 fix conflict
Former-commit-id: 6922b23a748c2459147bf44b96d86daa89f2c96c
2024-05-20 17:10:01 +08:00
hiyouga
ab48653e63 fix chat engines
do not use pop(key, default) since api assigns None to dict values


Former-commit-id: 3ebbd0b55ea07de2897c27ca54eeab5c3b319419
2024-05-20 00:36:43 +08:00
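A small illustration of the pitfall this commit refers to (the example payload and default value are made up): dict.pop(key, default) only falls back to the default when the key is missing, not when its value is None.

```python
payload = {"temperature": None}  # OpenAI-style clients may serialize unset fields as None

# pop(key, default) returns None here because the key exists, even though it is unset:
temperature = payload.pop("temperature", 0.7)   # -> None

# treating None as "unset" explicitly yields the intended default:
temperature = temperature if temperature is not None else 0.7  # -> 0.7
```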
hoshi-hiyouga
6d7a1e3f8f Merge pull request #3812 from ycjcl868/feat/chat-support-system-prompt
feat: cli chat support system_message
Former-commit-id: 96596990527403e910c81e95e38bf2638541cf31
2024-05-20 00:31:32 +08:00
hoshi-hiyouga
e093dad7cb Update vllm_engine.py
Former-commit-id: 0b8278bd21baf35d3f60c6ed24f110b391c92a47
2024-05-20 00:31:04 +08:00
hoshi-hiyouga
b103a121f0 Update hf_engine.py
Former-commit-id: ce8b902e538c69d89f207db8a43c85072cd70265
2024-05-20 00:30:45 +08:00
hoshi-hiyouga
3578abc7a4 Update generating_args.py
Former-commit-id: 861c146fa7d9cb5b99372464bd068c20fa36415d
2024-05-20 00:29:31 +08:00
hoshi-hiyouga
17d398f419 Update chat_model.py
Former-commit-id: 7736aafdc81d175e9fb484dbb7cae9263120a0fc
2024-05-20 00:29:12 +08:00
hiyouga
3453a8eebb fix jinja template
Former-commit-id: 353561f0e3914de3f81499c4e4b831ae0a6383b6
2024-05-19 23:38:30 +08:00
ycjcl868
77a089c35c feat: cli chat support system_message
Former-commit-id: e3982bff596d01992733687a580c4f41c558061c
2024-05-19 23:17:46 +08:00
hiyouga
516d83c946 fix zero2 high ram usage
Former-commit-id: 01797126eb173250250e31f8e76b69ae0047745d
2024-05-19 21:53:54 +08:00
hiyouga
fd02c9f973 fix hf gen args
Former-commit-id: 491a84976258cbb2a2647922420e2f84de1e38cd
2024-05-19 19:39:32 +08:00
hiyouga
351e80a656 fix envs
Former-commit-id: d5e150cfb98f8216713415564ab386b8320c88cb
2024-05-19 18:27:18 +08:00
hiyouga
4f04e2ed93 fix #3807
Former-commit-id: 08b695969049de8bf9bd3e90b9700736d90385ee
2024-05-19 17:07:57 +08:00
hiyouga
a810d1b98e update readme
Former-commit-id: e0beb67a417b13c818a09bd419d4e20dd44ca842
2024-05-18 23:09:03 +08:00
hiyouga
fbe963a96a safe output path in webui
Former-commit-id: 23f14262e0d54631630c084ba71e0433ea1d4640
2024-05-18 22:42:28 +08:00
hiyouga
d13b8bee8a fix jetmoe z3 block
Former-commit-id: cb00a14d905395c4b8fadb955f0424a4c56668de
2024-05-18 22:28:45 +08:00
hiyouga
0aa072a155 improve data process logger
Former-commit-id: 33d0b012b56dbafc9fff87b821c2d1bf1409dbb5
2024-05-18 22:02:42 +08:00
hiyouga
57dde7c3bc update data readme
Former-commit-id: 22c7335b496e4a673383d5a1e4e60bf2cb4e35b3
2024-05-18 21:37:38 +08:00
hiyouga
6b9003f781 update data readme
Former-commit-id: beb864a9367943d3274cb6057423d1eb9aaf85c4
2024-05-18 21:15:20 +08:00
hiyouga
9c1c59e481 fix #3803
Former-commit-id: 1ef12c95059d14a1717c82ce04e529e7ad6435ed
2024-05-18 16:13:14 +08:00
hoshi-hiyouga
31daec2749 Merge pull request #3799 from hiyouga/dev
improve KTO impl, replace datasets

Former-commit-id: b4cc207855aa1dbb120f7999165e176e649af338
2024-05-18 03:49:13 +08:00
hiyouga
2bff90719b improve KTO impl., replace datasets
Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
2024-05-18 03:44:56 +08:00
hoshi-hiyouga
e4570e28a8 Merge pull request #3785 from enji-zhou/feature/add_kto
add kto

Former-commit-id: f60faa23e23022fd855dac6b1ecbd21e095bccb5
2024-05-18 03:07:18 +08:00
hoshi-hiyouga
d84a730daa Merge pull request #3794 from jue-jue-zi/main
feat: pass the `max_lora_rank` parameter to vLLM backend
Former-commit-id: be839961686a1845f00a56e398a7b3779df8b6e4
2024-05-17 16:17:30 +08:00
hoshi-hiyouga
0fd1a05cec Update model_args.py
Former-commit-id: f40a2fe5334865763e4d513292d359317b7a091b
2024-05-17 16:16:41 +08:00
juejuezi
6373d307ec feat: pass the max_lora_rank parameter to vLLM backend
Former-commit-id: a8756d839405ecb5deabe885cf11d1a61564deee
2024-05-17 16:07:39 +08:00
hiyouga
a32c3a50fc add deepseek v2 lite model
Former-commit-id: 5e864e6b721d8b891b1cc2ca2dcac41babb9eaaf
2024-05-17 13:25:36 +08:00
enji.zhou
66b5634ebf add kto
Former-commit-id: ec51986cf70b0bdd79b8141e45916670fb97a08e
2024-05-17 13:09:17 +08:00
hiyouga
92b3697e2c update badam example #3764
Former-commit-id: a3730fd0a96bab869be6d695031182dabaea8137
2024-05-17 02:21:10 +08:00
hiyouga
969e605c7e better dtype handle in loading
Former-commit-id: 663f0577dd61a1a31191db2c6fbb0c7cea533b21
2024-05-17 02:14:56 +08:00
hiyouga
a3320f26cf update examples
Former-commit-id: 3b5f138155d96b346bda18e465cf60ec7d99e19c
2024-05-17 01:02:00 +08:00
hiyouga
45329d9e3c enable inbrowser in webui
Former-commit-id: 71fdeedb64b2339eb1c740d670b87e0c03dada68
2024-05-17 00:08:56 +08:00
hiyouga
6481321470 add falcon 11b
Former-commit-id: 897acc725edc204fad393cc9616828431b4fa768
2024-05-17 00:08:33 +08:00
hiyouga
efcf5e050d fix examples #3769
Former-commit-id: 80c036beb8d9ddac8f844f1818c9488ded04e86e
2024-05-16 19:12:09 +08:00
hiyouga
dfa686b617 rename package
Former-commit-id: a07ff0c083558cfe6f474d13027642d3052fee08
2024-05-16 18:39:08 +08:00
hiyouga
fe638cf11f set dev version
Former-commit-id: 5e9c72d07c3793cdccbdb8a9f95f1bb5d714e0a3
2024-05-16 02:17:31 +08:00
zhangzc
7cdc16abdf Supports custom data set sampling quantity
Former-commit-id: fa8325401df27595de4611a89dfcc14644956abd
2024-03-27 14:22:50 +08:00
225 changed files with 7947 additions and 3169 deletions

View File

@@ -4,6 +4,8 @@
 .venv
 cache
 data
+hf_cache
+output
 examples
 .dockerignore
 .gitattributes

View File

@@ -13,6 +13,18 @@ body:
         - label: I have read the README and searched the existing issues.
           required: true

+  - type: textarea
+    id: system-info
+    validations:
+      required: true
+    attributes:
+      label: System Info
+      description: |
+        Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
+        请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
+      placeholder: llamafactory version, platform, python version, ...
+
   - type: textarea
     id: reproduction
     validations:
@@ -26,7 +38,9 @@ body:
         请合理使用 Markdown 标签来格式化您的文本。
       placeholder: |
-        python src/train_bash.py ...
+        ```bash
+        llamafactory-cli train ...
+        ```

   - type: textarea
     id: expected-behavior
@@ -38,18 +52,6 @@ body:
         Please provide a clear and concise description of what you would expect to happen.
         请提供您原本的目的,即这段代码的期望行为。

-  - type: textarea
-    id: system-info
-    validations:
-      required: false
-    attributes:
-      label: System Info
-      description: |
-        Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
-        请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
-      placeholder: transformers version, platform, python version, ...
-
   - type: textarea
     id: others
     validations:

View File

@@ -5,3 +5,4 @@ Fixes # (issue)
 ## Before submitting

 - [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
+- [ ] Did you write any new necessary tests?

.github/workflows/label_issue.yml (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
name: label_issue

on:
  issues:
    types:
      - opened

jobs:
  label_issue:
    runs-on: ubuntu-latest
    steps:
      - env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          ISSUE_URL: ${{ github.event.issue.html_url }}
        run: |
          gh issue edit $ISSUE_URL --add-label "pending"

.github/workflows/publish.yml (new file, 40 lines)
View File

@@ -0,0 +1,40 @@
name: publish

on:
  release:
    types:
      - published

jobs:
  publish:
    name: Upload release to PyPI
    runs-on: ubuntu-latest
    environment:
      name: release
      url: https://pypi.org/p/llamafactory
    permissions:
      id-token: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.8"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install build

      - name: Build package
        run: |
          python -m build

      - name: Publish package
        uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -2,28 +2,44 @@ name: tests
 on:
   push:
-    branches: [ "main" ]
+    branches:
+      - main
+    paths:
+      - "**.py"
+      - "requirements.txt"
+      - ".github/workflows/*.yml"
   pull_request:
-    branches: [ "main" ]
+    branches:
+      - main
+    paths:
+      - "**.py"
+      - "requirements.txt"
+      - ".github/workflows/*.yml"

 jobs:
-  check_code_quality:
+  tests:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - name: Checkout
+        uses: actions/checkout@v4

       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: "3.8"
+          cache: "pip"
+          cache-dependency-path: "setup.py"

       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install ruff
+          python -m pip install .[torch,dev]

       - name: Check quality
         run: |
           make style && make quality
+
+      - name: Test with pytest
+        run: |
+          make test

View File

@@ -1,14 +1,47 @@
-FROM nvcr.io/nvidia/pytorch:24.01-py3
+# Use the NVIDIA official image with PyTorch 2.3.0
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
+FROM nvcr.io/nvidia/pytorch:24.02-py3
+
+# Define installation arguments
+ARG INSTALL_BNB=false
+ARG INSTALL_VLLM=false
+ARG INSTALL_DEEPSPEED=false
+ARG PIP_INDEX=https://pypi.org/simple

+# Set the working directory
 WORKDIR /app

+# Install the requirements
 COPY requirements.txt /app/
-RUN pip install -r requirements.txt
+RUN pip config set global.index-url $PIP_INDEX
+RUN python -m pip install --upgrade pip
+RUN python -m pip install -r requirements.txt

+# Copy the rest of the application into the image
 COPY . /app/
-RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]

+# Install the LLaMA Factory
+RUN EXTRA_PACKAGES="metrics"; \
+    if [ "$INSTALL_BNB" = "true" ]; then \
+        EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
+    fi; \
+    if [ "$INSTALL_VLLM" = "true" ]; then \
+        EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
+    fi; \
+    if [ "$INSTALL_DEEPSPEED" = "true" ]; then \
+        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
+    fi; \
+    pip install -e .[$EXTRA_PACKAGES] && \
+    pip uninstall -y transformer-engine flash-attn
+
+# Set up volumes
 VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
+
+# Expose port 7860 for the LLaMA Board
 EXPOSE 7860
+# Expose port 8000 for the API service
+EXPOSE 8000
+
+# Launch LLaMA Board
 CMD [ "llamafactory-cli", "webui" ]

MANIFEST.in (new file, 1 line)
View File

@@ -0,0 +1 @@
include LICENSE requirements.txt

View File

@@ -1,4 +1,4 @@
-.PHONY: quality style
+.PHONY: quality style test

 check_dirs := scripts src tests

@@ -9,3 +9,6 @@ quality:
 style:
 	ruff check $(check_dirs) --fix
 	ruff format $(check_dirs)
+
+test:
+	CUDA_VISIBLE_DEVICES= pytest tests/

README.md (243 changed lines)
View File

@@ -3,15 +3,15 @@
 [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
-[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
-[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
+[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
 [![Citation](https://img.shields.io/badge/citation-44-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
+[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
 [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -26,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89
 Choose your path:

 - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 - **Local machine**: Please refer to [usage](#getting-started)

 ## Table of Contents
@@ -46,9 +47,9 @@ Choose your path:
 ## Features

 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
-- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
+- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -70,14 +71,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

-[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
+[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.

-[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.

-[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
+[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.

 <details><summary>Full Changelog</summary>

+[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
+
+[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
+
+[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
+
 [24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.

 [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
@@ -104,7 +113,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.

-[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall`.
+[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.

 [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
## Supported Models

| Model | Model size | Template |
| ----- | ---------- | -------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
> Remember to use the **SAME** template in training and inference.

Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of supported models.

You can also add a custom chat template in [template.py](src/llamafactory/data/template.py).
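For example, a config for an instruct checkpoint pairs the model with its dedicated template, while a base checkpoint can use a generic one; the model IDs below are only illustrative.

```yaml
# instruct/chat model: use the corresponding template
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3

# base model: a generic template such as `default` or `alpaca` is acceptable
# model_name_or_path: meta-llama/Meta-Llama-3-8B
# template: default
```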
## Supported Training Approaches
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
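In the yaml configs, the approach is selected by the `stage` key and the tuning granularity by `finetuning_type`. The sketch below assumes the commonly used key names and values; check the shipped example configs for the authoritative ones.

```yaml
# pick the training approach and the tuning granularity (values are the commonly used ones)
stage: dpo                 # alternatives include pt, sft, rm, ppo and kto
finetuning_type: lora      # or full / freeze, matching the columns of the table above
pref_beta: 0.1             # preference temperature used by DPO-style stages (assumed default)
```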
## Provided Datasets
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
<details><summary>Supervised fine-tuning datasets</summary>

- [Identity (en&zh)](data/identity.json)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
<details><summary>Preference datasets</summary>

- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)

</details>
| Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- |
| python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.3.0 |
| transformers | 4.41.2 | 4.41.2 |
| datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 |
| trl | 0.8.6 | 0.9.4 |

| Optional | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.4.3 |
| flash-attn | 2.3.0 | 2.5.9 |
### Hardware Requirement
> Installation is mandatory.

```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
<details><summary>For Ascend NPU users</summary>

Join [NPU user group](assets/wechat_npu.jpg).

To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash
# replace the url according to your CANN version and devices
# install CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# install CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
# set env variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
| Requirement | Minimum | Recommend |
| ------------ | ------- | ----------- |
| CANN | 8.0.RC1 | 8.0.RC1 |
| torch | 2.1.0 | 2.1.0 |
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |

Docker image:

- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```

See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))

```bash
llamafactory-cli webui
```
### Build Docker

#### Use Docker

```bash
docker build -f ./Dockerfile \
    --build-arg INSTALL_BNB=false \
    --build-arg INSTALL_VLLM=false \
    --build-arg INSTALL_DEEPSPEED=false \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    -t llamafactory:latest .

docker run -it --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface/ \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
    -p 8000:8000 \
    --shm-size 16G \
    --name llamafactory \
    llamafactory:latest
```
#### Use Docker Compose

```bash
docker-compose up -d
docker-compose exec llamafactory bash
```
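For reference, the repository ships its own `docker-compose.yml`; the sketch below only illustrates how the volumes, ports and shared memory from the `docker run` command above would map onto a Compose service, so treat every value as an assumption rather than the actual file.

```yaml
# illustrative docker-compose service mirroring the docker run flags above
services:
  llamafactory:
    build:
      context: .
      dockerfile: ./Dockerfile
    container_name: llamafactory
    volumes:
      - ./hf_cache:/root/.cache/huggingface/
      - ./data:/app/data
      - ./output:/app/output
    ports:
      - "7860:7860"   # LLaMA Board GUI
      - "8000:8000"   # OpenAI-style API
    shm_size: "16gb"
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    tty: true
    stdin_open: true
```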
<details><summary>Details about volume</summary>
### Deploy with OpenAI-style API and vLLM

```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```

> [!TIP]
> Visit https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
### Download from ModelScope Hub

If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.

```bash
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
```

Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
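In config terms this only changes the model identifier; for example, assuming `USE_MODELSCOPE_HUB=1` is already exported:

```yaml
# load the model from ModelScope instead of Hugging Face
model_name_or_path: LLM-Research/Meta-Llama-3-8B-Instruct   # ModelScope model ID
template: llama3
```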
### Use W&B Logger
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
```yaml
report_to: wandb
run_name: test_run # optional
```
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
## Projects using LLaMA Factory

1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

View File

[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-44-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
Choose your path:

- **Colab**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Local machine**: see [How to Use](#如何使用)
## Table of Contents
## Features

- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (continuous) pre-training, multimodal supervised fine-tuning, reward modeling, PPO training, DPO training, KTO training, ORPO training, and more.
- **Scalable precisions**: 32-bit full-parameter tuning, 16-bit freeze-tuning, 16-bit LoRA tuning, and 2/4/8-bit QLoRA tuning via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Fast inference**: OpenAI-style API, browser UI and CLI powered by vLLM.
## Changelog

[24/06/16] We supported the **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README_zh.md) for usage details.

[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.

[24/05/26] We supported the **[SimPO](https://arxiv.org/abs/2405.14734)** preference alignment algorithm. See [examples](examples/README_zh.md) for usage details.

<details><summary>Full changelog</summary>

[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that PaliGemma are pre-trained models; you need to fine-tune them with the `gemma` template to obtain chat capability.

[24/05/18] We supported the **[KTO](https://arxiv.org/abs/2402.01306)** preference alignment algorithm. See [examples](examples/README_zh.md) for usage details.

[24/05/14] We supported training and inference on Ascend NPU devices. See the [installation](#安装-llama-factory) section for details.

[24/04/26] We supported fine-tuning the multimodal model **LLaVA-1.5**. See [examples](examples/README_zh.md) for usage details.

[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. The Hugging Face community has released two Llama-3 models fine-tuned with LLaMA Factory; see [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.

[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** training based on the [AstraMindAI repository](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README_zh.md) for usage details.

[24/02/05] The Qwen1.5 (Qwen2 beta) series models are supported for fine-tuning in LLaMA-Factory. See this [blog post](https://qwenlm.github.io/zh/blog/qwen1.5/) for details.

[24/01/18] We supported **agent tuning** for most models; fine-tuning with `dataset: glaive_toolcall_zh` equips the model with tool-calling abilities.

[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s LoRA training acceleration for the LLaMA, Mistral and Yi models. Use the `use_unsloth: true` argument to enable the unsloth patch. It delivers **170%** training speed; see [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
## Models

| Model | Model size | Template |
| ----- | ---------- | -------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
> Remember to use the **SAME** template in training and inference.

Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of supported models.

You can also add a custom chat template in [template.py](src/llamafactory/data/template.py).
## Training Approaches

| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## Datasets

- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
<details><summary>Supervised fine-tuning datasets</summary>

- [Identity (en&zh)](data/identity.json)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
<details><summary>Preference datasets</summary>

- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)

</details>
| Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- |
| python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.3.0 |
| transformers | 4.41.2 | 4.41.2 |
| datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 |
| trl | 0.8.6 | 0.9.4 |

| Optional | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.4.3 |
| flash-attn | 2.3.0 | 2.5.9 |
### Hardware Requirement
> This step is mandatory.

```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
<details><summary>For Ascend NPU users</summary>

Join the [NPU user group](assets/wechat_npu.jpg).

To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:

```bash
# replace the URL according to your CANN version and devices
# install CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# install CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
# set environment variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
| Requirement | Minimum | Recommend |
| ------------ | ------- | ----------- |
| CANN | 8.0.RC1 | 8.0.RC1 |
| torch | 2.1.0 | 2.1.0 |
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |

Docker image:

- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

Use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the compute device.

If inference fails to work properly, try setting `do_sample: false`.
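If that happens, the switch belongs in the generation settings of the inference config; a rough sketch (key names assumed from the common schema):

```yaml
# fall back to greedy decoding on Ascend NPU when sampling misbehaves
do_sample: false
temperature: 0.95   # ignored once do_sample is false
```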
Use the following three commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```

See [examples/README_zh.md](examples/README_zh.md) for advanced usage (including multi-GPU fine-tuning).
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))

```bash
llamafactory-cli webui
```
### Build Docker

#### Use Docker

```bash
docker build -f ./Dockerfile \
    --build-arg INSTALL_BNB=false \
    --build-arg INSTALL_VLLM=false \
    --build-arg INSTALL_DEEPSPEED=false \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    -t llamafactory:latest .

docker run -it --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface/ \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
    -p 8000:8000 \
    --shm-size 16G \
    --name llamafactory \
    llamafactory:latest
```
#### Use Docker Compose

```bash
docker-compose up -d
docker-compose exec llamafactory bash
```
<details><summary>Details about volume</summary>
### Deploy with OpenAI-style API and vLLM

```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```

> [!TIP]
> See https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
### Download from ModelScope Hub

If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope as follows.

```bash
export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
```

Set `model_name_or_path` to a model ID to load the corresponding model. You can browse all available models at the [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Use W&B Logger

To log experiment results with [Weights & Biases](https://wandb.ai), add the following arguments to the yaml files.

```yaml
report_to: wandb
run_name: test_run # optional
```

Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in to your W&B account.
## 使用了 LLaMA Factory 的项目

1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI 性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)

## 协议

本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。

使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## 引用

The [dataset_info.json](dataset_info.json) contains all available datasets. If you are using a custom dataset, please **make sure** to add a *dataset description* in `dataset_info.json` and specify `dataset: dataset_name` before training to use it.

Currently we support datasets in **alpaca** and **sharegpt** format.

```json
"dataset_name": {
  "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
  "ms_hub_url": "the name of the dataset repository on the ModelScope hub. (if specified, ignore script_url and file_name)",
  "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
  "file_name": "the name of the dataset folder or dataset file in this directory. (required if above are not specified)",
  "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
  "ranking": "whether the dataset is a preference dataset or not. (default: False)",
  "subset": "the name of the subset. (optional, default: None)",
  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
  "num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
  "columns (optional)": {
    "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
    "query": "the column name in the dataset containing the queries. (default: input)",
    "response": "the column name in the dataset containing the responses. (default: output)",
    "history": "the column name in the dataset containing the histories. (default: None)",
    "messages": "the column name in the dataset containing the messages. (default: conversations)",
    "system": "the column name in the dataset containing the system prompts. (default: None)",
    "tools": "the column name in the dataset containing the tool description. (default: None)",
    "images": "the column name in the dataset containing the image inputs. (default: None)",
    "chosen": "the column name in the dataset containing the chosen answers. (default: None)",
    "rejected": "the column name in the dataset containing the rejected answers. (default: None)",
    "kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
  },
  "tags (optional, used for the sharegpt format)": {
    "role_tag": "the key in the message represents the identity. (default: from)",
    "content_tag": "the key in the message represents the content. (default: value)",
    "user_tag": "the value of the role_tag represents the user. (default: human)",
    "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
    "observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
    "function_tag": "the value of the role_tag represents the function call. (default: function_call)",
    "system_tag": "the value of the role_tag represents the system prompt. (default: system, which will override the system column)"
  }
}
```
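As a quick illustration, the sketch below (the file name `my_dataset.json` and dataset name `my_dataset` are hypothetical, and paths assume the repository root as working directory) appends such a description to `dataset_info.json` using only the standard library; after that, setting `dataset: my_dataset` in the training config makes the dataset available:

```python
import json

# Hypothetical custom dataset stored at data/my_dataset.json in alpaca format.
new_entry = {
    "my_dataset": {
        "file_name": "my_dataset.json",
        "formatting": "alpaca",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }
}

with open("data/dataset_info.json", encoding="utf-8") as f:
    dataset_info = json.load(f)

dataset_info.update(new_entry)  # register the new dataset description

with open("data/dataset_info.json", "w", encoding="utf-8") as f:
    json.dump(dataset_info, f, ensure_ascii=False, indent=2)
```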
## Alpaca Format

### Supervised Fine-Tuning Dataset

- [Example dataset](alpaca_en_demo.json)

In supervised fine-tuning, the content of the `instruction` column will be concatenated with the content of the `input` column and used as the human prompt, i.e. the human prompt will be `instruction\ninput`. The `output` column represents the model response.

The `system` column will be used as the system prompt if specified.

The `history` column is a list consisting of string tuples representing prompt-response pairs in the history messages. Note that the responses in the history **will also be learned by the model** in supervised fine-tuning.

```json
[
  {
    "instruction": "human instruction (required)",
    "input": "human input (optional)",
    "output": "model response (required)",
    "system": "system prompt (optional)",
    "history": [
      ["human instruction in the first round (optional)", "model response in the first round (optional)"],
      ["human instruction in the second round (optional)", "model response in the second round (optional)"]
    ]
  }
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "system": "system",
    "history": "history"
  }
}
```
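To make the concatenation rule concrete, here is a minimal sketch in plain Python (it mirrors the description above rather than LLaMA-Factory's actual preprocessing code) that turns one alpaca record into the ordered list of supervised turns, history included:

```python
def alpaca_to_turns(example: dict) -> list:
    """Return (human prompt, model response) pairs in training order."""
    turns = [tuple(pair) for pair in example.get("history") or []]  # earlier rounds are also learned
    prompt = example["instruction"]
    if example.get("input"):
        prompt = prompt + "\n" + example["input"]  # human prompt = instruction\ninput
    turns.append((prompt, example["output"]))
    return turns


record = {
    "instruction": "Translate to French",
    "input": "Good morning",
    "output": "Bonjour",
    "history": [["Say hi", "Hi!"]],
}
print(alpaca_to_turns(record))
# [('Say hi', 'Hi!'), ('Translate to French\nGood morning', 'Bonjour')]
```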
### Pre-training Dataset

- [Example dataset](c4_demo.json)

In pre-training, only the `text` column will be used for model learning.

```json
[
  {"text": "document"},
  {"text": "document"}
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "columns": {
    "prompt": "text"
  }
}
```
### Preference Dataset

Preference datasets are used for reward modeling, DPO training and ORPO training.

A preference dataset requires a better response in the `chosen` column and a worse response in the `rejected` column.

```json
[
  {
    "instruction": "human instruction (required)",
    "input": "human input (optional)",
    "chosen": "chosen answer (required)",
    "rejected": "rejected answer (required)"
  }
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "ranking": true,
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "chosen": "chosen",
    "rejected": "rejected"
  }
}
```
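For intuition, each preference record expands into two completions of the same prompt, which is what reward modeling and DPO-style objectives compare; a rough sketch (not the project's actual preprocessing code):

```python
def expand_preference(example: dict) -> tuple:
    """Return (prompt, chosen, rejected); the prompt is built like the alpaca prompt above."""
    prompt = example["instruction"]
    if example.get("input"):
        prompt += "\n" + example["input"]
    return prompt, example["chosen"], example["rejected"]


prompt, better, worse = expand_preference({
    "instruction": "Name a prime number",
    "input": "",
    "chosen": "7",
    "rejected": "9",
})
```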
### KTO Dataset

- [Example dataset](kto_en_demo.json)

KTO datasets require an extra `kto_tag` column containing the boolean human feedback.
```json
[
{
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"kto_tag": "human feedback [true/false] (required)"
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"kto_tag": "kto_tag"
}
}
```
### Multimodal Dataset
- [Example dataset](mllm_demo.json)
Multimodal datasets require an `images` column containing the paths to the input images. Currently we only support one image per example.
```json
[
{
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"images": [
"image path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"images": "images"
}
}
```
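Since each record may carry only one image, a small sanity check such as the sketch below (the `data/mllm_data.json` path is hypothetical, and it assumes image paths are stored relative to the `data` folder, which may differ in your setup) can catch bad records before training:

```python
import json
from pathlib import Path

data_dir = Path("data")  # hypothetical dataset location
records = json.loads((data_dir / "mllm_data.json").read_text(encoding="utf-8"))

for i, rec in enumerate(records):
    images = rec.get("images", [])
    assert len(images) == 1, f"record {i}: expected exactly one image, got {len(images)}"
    assert (data_dir / images[0]).exists(), f"record {i}: missing image file {images[0]}"
```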
## Sharegpt Format

### Supervised Fine-Tuning Dataset

- [Example dataset](glaive_toolcall_en_demo.json)

Compared to the alpaca format, the sharegpt format allows the dataset to have **more roles**, such as human, gpt, observation and function. They are presented in a list of objects in the `conversations` column.

Note that the human and observation messages should appear in odd positions, while the gpt and function messages should appear in even positions.

```json
[
  {
    "conversations": [
      {
        "from": "human",
        "value": "human instruction"
      },
      {
        "from": "function_call",
        "value": "tool arguments"
      },
      {
        "from": "observation",
        "value": "tool result"
      },
      {
        "from": "gpt",
        "value": "model response"
      }
    ],
    "system": "system prompt (optional)",
    "tools": "tool description (optional)"
  }
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "conversations",
    "system": "system",
    "tools": "tools"
  },
  "tags": {
    "role_tag": "from",
    "content_tag": "value",
    "user_tag": "human",
    "assistant_tag": "gpt"
  }
}
```
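The odd/even role constraint is easy to verify mechanically; a minimal sketch assuming the default `from`/`value` keys shown above:

```python
ODD_ROLES = {"human", "observation"}    # allowed in the 1st, 3rd, ... messages
EVEN_ROLES = {"gpt", "function_call"}   # allowed in the 2nd, 4th, ... messages


def check_conversation(conversations: list) -> None:
    for i, message in enumerate(conversations):
        expected = ODD_ROLES if i % 2 == 0 else EVEN_ROLES
        if message["from"] not in expected:
            raise ValueError(f"message {i}: role '{message['from']}' not allowed here, expected one of {sorted(expected)}")


check_conversation([
    {"from": "human", "value": "What is 2 + 2?"},
    {"from": "gpt", "value": "4"},
])
```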
### Preference Dataset

- [Example dataset](dpo_en_demo.json)

Preference datasets in sharegpt format also require a better message in the `chosen` column and a worse message in the `rejected` column.
```json
[
{
"conversations": [
{
"from": "human",
"value": "human instruction"
},
{
"from": "gpt",
"value": "model response"
},
{
"from": "human",
"value": "human instruction"
}
],
"chosen": {
"from": "gpt",
"value": "chosen answer (required)"
},
"rejected": {
"from": "gpt",
"value": "rejected answer (required)"
}
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"ranking": true,
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
### OpenAI Format
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
```json
[
  {
    "messages": [
      {
        "role": "system",
        "content": "system prompt (optional)"
      },
      {
        "role": "user",
        "content": "human instruction"
      },
      {
        "role": "assistant",
        "content": "model response"
      }
    ]
  }
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "messages"
  },
  "tags": {
    "role_tag": "role",
    "content_tag": "content",
    "user_tag": "user",
    "assistant_tag": "assistant",
    "system_tag": "system"
  }
}
```

The KTO datasets and multimodal datasets in sharegpt format are similar to the alpaca format.

Pre-training datasets are **incompatible** with the sharegpt format.

[dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**在 `dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集。

目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。

```json
"数据集名称": {
  "hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
  "ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
  "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
  "file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)",
  "formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
  "ranking": "是否为偏好数据集(可选,默认:False)",
  "subset": "数据集子集的名称(可选,默认:None)",
  "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
  "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
  "columns(可选)": {
    "prompt": "数据集代表提示词的表头名称(默认:instruction)",
    "query": "数据集代表请求的表头名称(默认:input)",
    "response": "数据集代表回答的表头名称(默认:output)",
    "history": "数据集代表历史对话的表头名称(默认:None)",
    "messages": "数据集代表消息列表的表头名称(默认:conversations)",
    "system": "数据集代表系统提示的表头名称(默认:None)",
    "tools": "数据集代表工具描述的表头名称(默认:None)",
    "images": "数据集代表图像输入的表头名称(默认:None)",
    "chosen": "数据集代表更优回答的表头名称(默认:None)",
    "rejected": "数据集代表更差回答的表头名称(默认:None)",
    "kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
  },
  "tags(可选,用于 sharegpt 格式)": {
    "role_tag": "消息中代表发送者身份的键名(默认:from)",
    "content_tag": "消息中代表文本内容的键名(默认:value)",
    "user_tag": "消息中代表用户的 role_tag(默认:human)",
    "assistant_tag": "消息中代表助手的 role_tag(默认:gpt)",
    "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
    "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
    "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system column)"
  }
}
```
## Alpaca 格式

### 指令监督微调数据集

- [样例数据集](alpaca_zh_demo.json)

在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。

如果指定,`system` 列对应的内容将被作为系统提示词。

`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容**也会被用于模型学习**。

```json
[
  {
    "instruction": "人类指令(必填)",
    "input": "人类输入(选填)",
    "output": "模型回答(必填)",
    "system": "系统提示词(选填)",
    "history": [
      ["第一轮指令(选填)", "第一轮回答(选填)"],
      ["第二轮指令(选填)", "第二轮回答(选填)"]
    ]
  }
]
```

对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:

```json
"数据集名称": {
  "file_name": "data.json",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "system": "system",
    "history": "history"
  }
}
```
### 预训练数据集

- [样例数据集](c4_demo.json)

在预训练时,只有 `text` 列中的内容会用于模型学习。

```json
[
  {"text": "document"},
  {"text": "document"}
]
```

对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:

```json
"数据集名称": {
  "file_name": "data.json",
  "columns": {
    "prompt": "text"
  }
}
```
### 偏好数据集

偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。

它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。

```json
[
  {
    "instruction": "人类指令(必填)",
    "input": "人类输入(选填)",
    "chosen": "优质回答(必填)",
    "rejected": "劣质回答(必填)"
  }
]
```

对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:

```json
"数据集名称": {
  "file_name": "data.json",
  "ranking": true,
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "chosen": "chosen",
    "rejected": "rejected"
  }
}
```
### KTO 数据集

- [样例数据集](kto_en_demo.json)
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
```json
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"kto_tag": "人类反馈 [true/false](必填)"
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"kto_tag": "kto_tag"
}
}
```
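作为补充说明,下面的示例(仅为示意,`prompt`、`response`、`liked` 均为假设的原始日志字段名)展示了如何把点赞/点踩反馈整理成上述 KTO 格式:

```python
# 假设的原始反馈日志:每条包含提示词、模型回答和用户是否点赞
raw_feedback = [
    {"prompt": "介绍一下大熊猫", "response": "大熊猫是中国特有的哺乳动物……", "liked": True},
    {"prompt": "1 + 1 等于几", "response": "3", "liked": False},
]

kto_records = [
    {
        "instruction": item["prompt"],
        "input": "",
        "output": item["response"],
        "kto_tag": item["liked"],  # bool 类型的人类反馈
    }
    for item in raw_feedback
]
```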
### 多模态数据集
- [样例数据集](mllm_demo.json)
多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。目前我们仅支持单张图像输入。
```json
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"images": [
"图像路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"images": "images"
}
}
```
## Sharegpt 格式
### 指令监督微调数据集
- [样例数据集](glaive_toolcall_zh_demo.json)
相比 alpaca 格式的数据集,sharegpt 格式支持**更多的角色种类**,例如 human、gpt、observation、function 等等。它们构成一个对象列表,呈现在 `conversations` 列中。

注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。
```json
[
  {
    "conversations": [
      {
        "from": "human",
        "value": "人类指令"
      },
      {
        "from": "function_call",
        "value": "工具参数"
      },
      {
        "from": "observation",
        "value": "工具结果"
      },
      {
        "from": "gpt",
        "value": "模型回答"
      }
    ],
    "system": "系统提示词(选填)",
    "tools": "工具描述(选填)"
  }
]
```

对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:

```json
"数据集名称": {
  "file_name": "data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "conversations",
    "system": "system",
    "tools": "tools"
  },
  "tags": {
    "role_tag": "from",
    "content_tag": "value",
    "user_tag": "human",
    "assistant_tag": "gpt"
  }
}
```
### 偏好数据集

- [样例数据集](dpo_zh_demo.json)
Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的消息,并在 `rejected` 列中提供更差的消息。
```json
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "gpt",
"value": "模型回答"
},
{
"from": "human",
"value": "人类指令"
}
],
"chosen": {
"from": "gpt",
"value": "优质回答"
},
"rejected": {
"from": "gpt",
"value": "劣质回答"
}
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"ranking": true,
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
### OpenAI 格式
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
```json
[
  {
    "messages": [
      {
        "role": "system",
        "content": "系统提示词(选填)"
      },
      {
        "role": "user",
        "content": "人类指令"
      },
      {
        "role": "assistant",
        "content": "模型回答"
      }
    ]
  }
]
```

对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:

```json
"数据集名称": {
  "file_name": "data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "messages"
  },
  "tags": {
    "role_tag": "role",
    "content_tag": "content",
    "user_tag": "user",
    "assistant_tag": "assistant",
    "system_tag": "system"
  }
}
```

Sharegpt 格式中的 KTO 数据集和多模态数据集与 alpaca 格式的类似。

预训练数据集**不支持** sharegpt 格式。

```python
import json
from typing import Any, Dict, Generator, List, Tuple

import datasets


_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"


class ExampleDataset(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features(
            {
                "instruction": datasets.Value("string"),
                "input": datasets.Value("string"),
                "output": datasets.Value("string"),
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download(_URL)
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
        example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
        for key, example in enumerate(example_dataset):
            yield key, example
```

```python
# Excerpt from the HhRlhfEn dataset loading script: preference pairs are stored
# in separate "chosen" and "rejected" columns.
features = datasets.Features(
    {
        "instruction": datasets.Value("string"),
        "chosen": datasets.Value("string"),
        "rejected": datasets.Value("string"),
        "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
    }
)

# ...

yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
key += 1
```

```yaml
services:
  llamafactory:
    build:
      dockerfile: Dockerfile
      context: .
      args:
        INSTALL_BNB: false
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ./hf_cache:/root/.cache/huggingface/
      - ./data:/app/data
      - ./output:/app/output
    ports:
      - "7860:7860"
      - "8000:8000"
    ipc: host
    tty: true
    stdin_open: true
    command: bash
    deploy:
      resources:
        reservations:
```

```python
class MMLU(datasets.GeneratorBasedBuilder):
    # ...

    def _generate_examples(self, filepath):
        df = pd.read_csv(filepath, header=None)
        df.columns = ["question", "A", "B", "C", "D", "answer"]
        for i, instance in enumerate(df.to_dict(orient="records")):
            ...
```

Make sure to execute these commands in the `LLaMA-Factory` directory.

## Table of Contents

- [LoRA Fine-Tuning](#lora-fine-tuning)
- [QLoRA Fine-Tuning](#qlora-fine-tuning)
- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
- [Extras](#extras)

Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.

## Examples

### LoRA Fine-Tuning

#### (Continuous) Pre-Training

```bash
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```

#### Supervised Fine-Tuning

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### Multimodal Supervised Fine-Tuning

```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```

#### Reward Modeling

```bash
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```

#### PPO Training

```bash
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```

#### DPO/ORPO/SimPO Training

```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```

#### KTO Training

```bash
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```

#### Preprocess Dataset

It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.

```bash
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```

#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks

```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```

#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```

#### Supervised Fine-Tuning on Multiple Nodes

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)

```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```

### QLoRA Fine-Tuning

#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
```

#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
```

#### Supervised Fine-Tuning with 4-bit AWQ Quantization

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
```

#### Supervised Fine-Tuning with 2-bit AQLM Quantization

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
```

### Full-Parameter Fine-Tuning

#### Supervised Fine-Tuning on Single Node

```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```

#### Supervised Fine-Tuning on Multiple Nodes

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```

#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```

### Merging LoRA Adapters and Quantization

#### Merge LoRA Adapters

Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.

```bash
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```

#### Quantizing Model using AutoGPTQ

```bash
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```

### Inferring LoRA Fine-Tuned Models

#### Use CLI

```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```

#### Use Web UI

```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```

#### Launch OpenAI-style API

```bash
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```

### Extras

#### Full-Parameter Fine-Tuning using GaLore

```bash
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```

#### Full-Parameter Fine-Tuning using BAdam

```bash
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```

#### LoRA+ Fine-Tuning

```bash
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
```

#### PiSSA Fine-Tuning

```bash
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```

#### Mixture-of-Depths Fine-Tuning

```bash
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```

#### LLaMA-Pro Fine-Tuning

```bash
bash examples/extras/llama_pro/expand.sh
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```

#### FSDP+QLoRA Fine-Tuning

```bash
bash examples/extras/fsdp_qlora/train.sh
```

## 目录

- [LoRA 微调](#lora-微调)
- [QLoRA 微调](#qlora-微调)
- [全参数微调](#全参数微调)
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
- [推理 LoRA 模型](#推理-lora-模型)
- [杂项](#杂项)

使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。

## 示例

### LoRA 微调

#### (增量)预训练

```bash
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```

#### 指令监督微调

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### 多模态指令监督微调

```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```

#### 奖励模型训练

```bash
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```

#### PPO 训练

```bash
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```

#### DPO/ORPO/SimPO 训练

```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```

#### KTO 训练

```bash
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```

#### 预处理数据集

对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。

```bash
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```

#### 在 MMLU/CMMLU/C-Eval 上评估

```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```

#### 批量预测并计算 BLEU 和 ROUGE 分数

```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```

#### 多机指令监督微调

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### 使用 DeepSpeed ZeRO-3 平均分配显存

```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```

### QLoRA 微调

#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
```

#### 基于 4/8 比特 GPTQ 量化进行指令监督微调

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
```

#### 基于 4 比特 AWQ 量化进行指令监督微调

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
```

#### 基于 2 比特 AQLM 量化进行指令监督微调

```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
```

### 全参数微调

#### 在单机上进行指令监督微调

```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```

#### 在多机上进行指令监督微调

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```

#### 批量预测并计算 BLEU 和 ROUGE 分数

```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```

### 合并 LoRA 适配器与模型量化

#### 合并 LoRA 适配器

注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。

```bash
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```

#### 使用 AutoGPTQ 量化模型

```bash
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```

### 推理 LoRA 模型

#### 使用命令行接口

```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```

#### 使用浏览器界面

```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```

#### 启动 OpenAI 风格 API

```bash
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```

### 杂项

#### 使用 GaLore 进行全参数训练

```bash
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```

#### 使用 BAdam 进行全参数训练

```bash
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```

#### LoRA+ 微调

```bash
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
```

#### PiSSA 微调

```bash
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```

#### 深度混合微调

```bash
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```

#### LLaMA-Pro 微调

```bash
bash examples/extras/llama_pro/expand.sh
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```

#### FSDP+QLoRA 微调

```bash
bash examples/extras/fsdp_qlora/train.sh
```

```yaml
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: true # offload may affect training speed
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: fp16 # or bf16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
```

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1 # the number of nodes
num_processes: 4 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 1
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

```yaml
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

### method
stage: sft
do_train: true
finetuning_type: full
use_badam: true
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
```

```yaml
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
```

```bash
#!/bin/bash
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA

CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file examples/accelerate/fsdp_config.yaml \
    src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
```

```yaml
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

### method
stage: sft
do_train: true
finetuning_type: full
use_galore: true
galore_layerwise: true
galore_target: mlp,self_attn
galore_rank: 128
galore_scale: 2.0

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
```

```yaml
### model
model_name_or_path: models/llama3-8b-instruct-pro

### method
stage: sft
do_train: true
finetuning_type: freeze
freeze_trainable_layers: 8
freeze_trainable_modules: all
use_llama_pro: true

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b-instruct-pro/freeze/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
```

View File

@@ -1,39 +1,40 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
loraplus_lr_ratio: 16.0 loraplus_lr_ratio: 16.0
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,40 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
mixture_of_depths: convert mixture_of_depths: convert
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b-mod/full/sft output_dir: saves/llama3-8b-mod/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
optim: paged_adamw_8bit optim: paged_adamw_8bit
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,41 +1,42 @@
-# model
+### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-# method
+### method
 stage: sft
 do_train: true
 finetuning_type: lora
-lora_target: q_proj,v_proj
+lora_target: all
+pissa_init: true
+pissa_iter: 4
+pissa_convert: true
-# ddp
-ddp_timeout: 180000000
-# dataset
-dataset: identity,alpaca_gpt4_en
+### dataset
+dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 1024
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-# output
+### output
 output_dir: saves/llama3-8b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
-# train
+### train
 per_device_train_batch_size: 1
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 8
-learning_rate: 0.0001
+learning_rate: 1.0e-4
 num_train_epochs: 3.0
 lr_scheduler_type: cosine
-warmup_steps: 0.1
+warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
-# eval
+### eval
 val_size: 0.1
 per_device_eval_batch_size: 1
-evaluation_strategy: steps
+eval_strategy: steps
 eval_steps: 500

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=2
RANK=0
MASTER_ADDR=192.168.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@@ -1,5 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/single_config.yaml \
src/train.py examples/full_multi_gpu/llama3_full_predict.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml

View File

@@ -1,6 +0,0 @@
#!/bin/bash
# also launch it on slave machine using slave_config.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/master_config.yaml \
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@@ -1,5 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/single_config.yaml \
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml

View File

@@ -1,8 +1,8 @@
-# model
+### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 template: llama3
-# export
+### export
 export_dir: models/llama3_gptq
 export_quantization_bit: 4
 export_quantization_dataset: data/c4_demo.json
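A minimal sketch of running this post-training GPTQ export; the `export` subcommand of the llamafactory-cli entry point and the examples/merge_lora/ path are assumptions, not shown in this diff:

llamafactory-cli export examples/merge_lora/llama3_gptq.yaml  # assumed subcommand and config path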

View File

@@ -1,12 +1,12 @@
-# Note: DO NOT use quantized model or quantization_bit when merging lora adapters
+### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
-# model
+### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
 template: llama3
 finetuning_type: lora
-# export
+### export
 export_dir: models/llama3_lora_sft
 export_size: 2
 export_device: cpu

View File

@@ -1,23 +1,23 @@
-# model
+### model
 model_name_or_path: saves/llama3-8b/full/sft
-# method
+### method
 stage: sft
 do_predict: true
 finetuning_type: full
-# dataset
+### dataset
-dataset: identity,alpaca_gpt4_en
+dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
-# output
+### output
 output_dir: saves/llama3-8b/full/predict
 overwrite_output_dir: true
-# eval
+### eval
 per_device_eval_batch_size: 1
 predict_with_generate: true

View File

@@ -1,41 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json deepspeed: examples/deepspeed/ds_z3_config.json
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,41 @@
-# model
+### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-# method
+### method
 stage: dpo
 do_train: true
 finetuning_type: lora
-lora_target: q_proj,v_proj
+lora_target: all
-dpo_ftx: 1.0
+pref_beta: 0.1
+pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
-# dataset
+### dataset
-dataset: orca_rlhf
+dataset: dpo_en_demo
 template: llama3
 cutoff_len: 1024
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-# output
+### output
 output_dir: saves/llama3-8b/lora/dpo
 logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
-# train
+### train
 per_device_train_batch_size: 1
 gradient_accumulation_steps: 8
-learning_rate: 0.00001
+learning_rate: 5.0e-6
 num_train_epochs: 3.0
 lr_scheduler_type: cosine
-warmup_steps: 0.1
+warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
-# eval
+### eval
 val_size: 0.1
 per_device_eval_batch_size: 1
-evaluation_strategy: steps
+eval_strategy: steps
 eval_steps: 500
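A minimal sketch of launching this DPO recipe; the `train` subcommand is an assumption based on the llamafactory-cli entry point, and the config path is a placeholder:

llamafactory-cli train path/to/llama3_lora_dpo.yaml  # assumed subcommand; substitute the real example path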

View File

@@ -1,19 +1,19 @@
-# model
+### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
-# method
+### method
 finetuning_type: lora
-# dataset
+### dataset
 task: mmlu
 split: test
 template: fewshot
 lang: en
 n_shot: 5
-# output
+### output
 save_dir: saves/llama3-8b/lora/eval
-# eval
+### eval
 batch_size: 4
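A minimal sketch of running this MMLU evaluation config; the `eval` subcommand is an assumption based on the llamafactory-cli entry point, and the path is a placeholder:

llamafactory-cli eval path/to/llama3_lora_eval.yaml  # assumed subcommand; substitute the real example path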

View File

@@ -1,38 +1,40 @@
-# model
+### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-# method
+### method
-stage: orpo
+stage: kto
 do_train: true
 finetuning_type: lora
-lora_target: q_proj,v_proj
+lora_target: all
+pref_beta: 0.1
-# dataset
+### dataset
-dataset: orca_rlhf
+dataset: kto_en_demo
 template: llama3
 cutoff_len: 1024
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-# output
+### output
-output_dir: saves/llama3-8b/lora/orpo
+output_dir: saves/llama3-8b/lora/kto
 logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
-# train
+### train
 per_device_train_batch_size: 1
 gradient_accumulation_steps: 8
-learning_rate: 0.00001
+learning_rate: 5.0e-6
 num_train_epochs: 3.0
 lr_scheduler_type: cosine
-warmup_steps: 0.1
+warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
-# eval
+### eval
 val_size: 0.1
 per_device_eval_batch_size: 1
-evaluation_strategy: steps
+eval_strategy: steps
 eval_steps: 500

View File

@@ -1,38 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reward_model: saves/llama3-8b/lora/reward reward_model: saves/llama3-8b/lora/reward
# method ### method
stage: ppo stage: ppo
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/ppo output_dir: saves/llama3-8b/lora/ppo
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# generate ### generate
max_new_tokens: 512 max_new_tokens: 512
top_k: 0 top_k: 0
top_p: 0.9 top_p: 0.9

View File

@@ -1,24 +1,25 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft adapter_name_or_path: saves/llama3-8b/lora/sft
# method ### method
stage: sft stage: sft
do_predict: true do_predict: true
finetuning_type: lora finetuning_type: lora
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 50 max_samples: 50
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/predict output_dir: saves/llama3-8b/lora/predict
overwrite_output_dir: true overwrite_output_dir: true
# eval ### eval
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
predict_with_generate: true predict_with_generate: true
ddp_timeout: 180000000

View File

@@ -1,37 +1,38 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: pt stage: pt
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: c4_demo dataset: c4_demo
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: rm stage: rm
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: orca_rlhf dataset: dpo_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/reward output_dir: saves/llama3-8b/lora/reward
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,42 +1,40 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json deepspeed: examples/deepspeed/ds_z0_config.json
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,42 +1,40 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json deepspeed: examples/deepspeed/ds_z3_config.json
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,14 +1,14 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
@@ -16,6 +16,6 @@ overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
tokenized_path: saves/llama3-8b/dataset/sft tokenized_path: saves/llama3-8b/dataset/sft
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
overwrite_output_dir: true overwrite_output_dir: true

View File

@@ -1,14 +1,14 @@
# model ### model
model_name_or_path: llava-hf/llava-1.5-7b-hf model_name_or_path: llava-hf/llava-1.5-7b-hf
visual_inputs: true visual_inputs: true
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: mllm_demo dataset: mllm_demo
template: vicuna template: vicuna
cutoff_len: 1024 cutoff_len: 1024
@@ -16,24 +16,25 @@ max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llava1_5-7b/lora/sft output_dir: saves/llava1_5-7b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,39 @@
# model ### model
model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16 model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,39 @@
# model ### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,40 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4 quantization_bit: 4
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,39 @@
# model ### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -13,7 +13,7 @@ select = ["C", "E", "F", "I", "W"]
 [tool.ruff.lint.isort]
 lines-after-imports = 2
-known-first-party = ["llmtuner"]
+known-first-party = ["llamafactory"]
 known-third-party = [
     "accelerate",
     "datasets",

View File

@@ -1,12 +1,14 @@
-transformers>=4.37.2
+transformers>=4.41.2
-datasets>=2.14.3
+datasets>=2.16.0
-accelerate>=0.27.2
+accelerate>=0.30.1
-peft>=0.10.0
+peft>=0.11.1
-trl>=0.8.1
+trl>=0.8.6
 gradio>=4.0.0
+pandas>=2.0.0
 scipy
 einops
 sentencepiece
+tiktoken
 protobuf
 uvicorn
 pydantic

View File

@@ -1,14 +1,27 @@
# coding=utf-8 # coding=utf-8
# Calculates the flops of pre-trained models. # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 #
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/ # This code is inspired by the Microsoft's DeepSpeed library.
# https://www.deepspeed.ai/tutorials/flops-profiler/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fire import fire
import torch import torch
from deepspeed.accelerator import get_accelerator # type: ignore from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner.chat import ChatModel from llamafactory.chat import ChatModel
def calculate_flops( def calculate_flops(
@@ -17,6 +30,10 @@ def calculate_flops(
seq_length: int = 256, seq_length: int = 256,
flash_attn: str = "auto", flash_attn: str = "auto",
): ):
r"""
Calculates the flops of pre-trained models.
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
"""
with get_accelerator().device(0): with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
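The usage line moved from the header comment into the function docstring; a minimal invocation sketch, assuming the helper lives under scripts/ like the other utilities in this change set:

python scripts/cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512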

View File

@@ -1,7 +1,20 @@
# coding=utf-8 # coding=utf-8
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. # Copyright 2024 imoneoi and the LlamaFactory team.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 #
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py # This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
from typing import Literal from typing import Literal
@@ -12,10 +25,10 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llamafactory.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args from llamafactory.hparams import get_train_args
from llmtuner.model import load_tokenizer from llamafactory.model import load_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
@@ -32,6 +45,10 @@ def calculate_lr(
cutoff_len: int = 1024, # i.e. maximum input length during training cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate, is_mistral: bool = False, # mistral model uses a smaller learning rate,
): ):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
"""
model_args, data_args, training_args, _, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
stage=stage, stage=stage,
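The same pattern applies to the learning-rate helper; the command below repeats its docstring usage, with the scripts/ prefix assumed:

python scripts/cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16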

View File

@@ -1,6 +1,17 @@
# coding=utf-8 # coding=utf-8
# Calculates the ppl on the dataset of the pre-trained models. # Copyright 2024 the LlamaFactory team.
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from dataclasses import dataclass from dataclasses import dataclass
@@ -12,10 +23,10 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llamafactory.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args from llamafactory.hparams import get_train_args
from llmtuner.model import load_model, load_tokenizer from llamafactory.model import load_model, load_tokenizer
@dataclass @dataclass
@@ -56,6 +67,10 @@ def cal_ppl(
max_samples: Optional[int] = None, max_samples: Optional[int] = None,
train_on_prompt: bool = False, train_on_prompt: bool = False,
): ):
r"""
Calculates the ppl on the dataset of the pre-trained models.
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict( dict(
stage=stage, stage=stage,
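Invocation sketch taken from the docstring above; the scripts/ prefix is an assumption:

python scripts/cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json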

View File

@@ -1,15 +1,26 @@
# coding=utf-8 # coding=utf-8
# Calculates the distribution of the input lengths in the dataset. # Copyright 2024 the LlamaFactory team.
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict from collections import defaultdict
import fire import fire
from tqdm import tqdm from tqdm import tqdm
from llmtuner.data import get_dataset from llamafactory.data import get_dataset
from llmtuner.hparams import get_train_args from llamafactory.hparams import get_train_args
from llmtuner.model import load_tokenizer from llamafactory.model import load_tokenizer
def length_cdf( def length_cdf(
@@ -19,6 +30,10 @@ def length_cdf(
template: str = "default", template: str = "default",
interval: int = 1000, interval: int = 1000,
): ):
r"""
Calculates the distribution of the input lengths in the dataset.
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
"""
model_args, data_args, training_args, _, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
stage="sft", stage="sft",
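Invocation sketch taken from the docstring above; the scripts/ prefix is an assumption:

python scripts/length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default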

View File

@@ -1,7 +1,20 @@
# coding=utf-8 # coding=utf-8
# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. # Copyright 2024 Tencent Inc. and the LlamaFactory team.
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 #
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py # This code is inspired by the Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@@ -37,6 +50,10 @@ def block_expansion(
shard_size: Optional[str] = "2GB", shard_size: Optional[str] = "2GB",
save_safetensors: Optional[bool] = False, save_safetensors: Optional[bool] = False,
): ):
r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
num_layers = getattr(config, "num_hidden_layers") num_layers = getattr(config, "num_hidden_layers")
setattr(config, "num_hidden_layers", num_layers + num_expand) setattr(config, "num_hidden_layers", num_layers + num_expand)
@@ -103,11 +120,11 @@ def block_expansion(
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print("Model weights saved in {}".format(output_dir))
print("Fine-tune this model with:") print("- Fine-tune this model with:")
print(" --model_name_or_path {} \\".format(output_dir)) print("model_name_or_path: {}".format(output_dir))
print(" --finetuning_type freeze \\") print("finetuning_type: freeze")
print(" --freeze_trainable_layers {} \\".format(num_expand)) print("freeze_trainable_layers: {}".format(num_expand))
print(" --use_llama_pro") print("use_llama_pro: true")
if __name__ == "__main__": if __name__ == "__main__":
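The block-expansion script now prints its follow-up settings in YAML form rather than as CLI flags. A minimal end-to-end sketch combining the docstring usage with the printed hints (the output directory name is illustrative and the scripts/ prefix is assumed):

python scripts/llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
# then fine-tune the expanded model with the settings the script prints:
#   model_name_or_path: llama2_pro
#   finetuning_type: freeze
#   freeze_trainable_layers: 8
#   use_llama_pro: true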

View File

@@ -1,8 +1,17 @@
# coding=utf-8 # coding=utf-8
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B. # Copyright 2024 the LlamaFactory team.
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output #
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py # Licensed under the Apache License, Version 2.0 (the "License");
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
def llamafy_baichuan2( def llamafy_baichuan2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
): ):
r"""
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
"""
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:

View File

@@ -1,7 +1,17 @@
# coding=utf-8 # coding=utf-8
# Converts the Qwen models in the same format as LLaMA2. # Copyright 2024 the LlamaFactory team.
# Usage: python llamafy_qwen.py --input_dir input --output_dir output #
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
def llamafy_qwen( def llamafy_qwen(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
): ):
r"""
Converts the Qwen models in the same format as LLaMA2.
Usage: python llamafy_qwen.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
"""
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:

View File

@@ -1,14 +1,25 @@
# coding=utf-8 # coding=utf-8
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir #
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py # This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import fire import fire
import torch
import torch.nn as nn
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -17,38 +28,21 @@ if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PreTrainedModel
class Shell(nn.Module):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)
if bias is not None:
self.bias = nn.Parameter(bias, requires_grad=False)
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
parent_name = ".".join(name.split(".")[:-1])
child_name = name.split(".")[-1]
parent_module = model.get_submodule(parent_name)
child_module = getattr(parent_module, child_name)
base_layer = getattr(child_module, "base_layer")
weight = getattr(base_layer, "weight", None)
bias = getattr(base_layer, "bias", None)
setattr(parent_module, child_name, Shell(weight, bias))
print("Model unwrapped.")
def quantize_loftq( def quantize_loftq(
model_name_or_path: str, model_name_or_path: str,
save_dir: str, output_dir: str,
loftq_bits: Optional[int] = 4, loftq_bits: int = 4,
loftq_iter: Optional[int] = 1, loftq_iter: int = 4,
lora_alpha: Optional[int] = None, lora_alpha: int = None,
lora_rank: Optional[int] = 16, lora_rank: int = 16,
lora_target: Optional[str] = "q_proj,v_proj", lora_dropout: float = 0,
save_safetensors: Optional[bool] = False, lora_target: str = "q_proj,v_proj",
save_safetensors: bool = True,
): ):
r"""
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter) loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
@@ -57,25 +51,34 @@ def quantize_loftq(
inference_mode=True, inference_mode=True,
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=0.1, lora_dropout=lora_dropout,
target_modules=[name.strip() for name in lora_target.split(",")], target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="loftq", init_lora_weights="loftq",
loftq_config=loftq_config, loftq_config=loftq_config,
) )
# Init LoftQ model # Init LoftQ model
lora_model = get_peft_model(model, lora_config) print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
base_model: "PreTrainedModel" = lora_model.get_base_model() peft_model = get_peft_model(model, lora_config)
loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model # Save LoftQ model
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir) setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True) setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors) peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir))
# Save base model # Save base model
unwrap_model(base_model) base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(save_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(save_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(loftq_dir))
print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits))
if __name__ == "__main__": if __name__ == "__main__":
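A minimal sketch of the reworked LoftQ flow, following the updated docstring and the printed hints; the scripts/ prefix and the output directory name are placeholders:

python scripts/loftq_init.py --model_name_or_path path_to_model --output_dir loftq_output
# the base model is saved in loftq_output/ and the adapter in loftq_output/loftq_init/;
# the script then suggests fine-tuning with:
#   model_name_or_path: loftq_output
#   adapter_name_or_path: loftq_output/loftq_init
#   finetuning_type: lora
#   quantization_bit: 4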

scripts/pissa_init.py (new file, 82 lines)
View File

@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import fire
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedModel
def quantize_pissa(
model_name_or_path: str,
output_dir: str,
pissa_iter: int = 4,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
lora_target: str = "q_proj,v_proj",
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
)
# Init PiSSA model
peft_model = get_peft_model(model, lora_config)
pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir))
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(pissa_dir))
print("finetuning_type: lora")
print("pissa_init: false")
print("pissa_convert: true")
print("- and optionally with:")
print("quantization_bit: 4")
if __name__ == "__main__":
fire.Fire(quantize_pissa)
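A matching sketch for the new PiSSA initializer, based on its docstring and printed hints; the scripts/ prefix and output directory name are placeholders:

python scripts/pissa_init.py --model_name_or_path path_to_model --output_dir pissa_output
# the adapter is saved in pissa_output/pissa_init/ and the base model in pissa_output/; fine-tune with:
#   model_name_or_path: pissa_output
#   adapter_name_or_path: pissa_output/pissa_init
#   finetuning_type: lora
#   pissa_init: false
#   pissa_convert: true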

View File

@@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
from typing import Sequence from typing import Sequence
@@ -20,7 +35,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
 def main():
     client = OpenAI(
-        api_key="0",
+        api_key="{}".format(os.environ.get("API_KEY", "0")),
         base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
     )
     tools = [
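For context, a hedged sketch of how a caller would pair with the patched client above; the key and port values are placeholders, not project defaults.

import os

from openai import OpenAI

os.environ.setdefault("API_KEY", "my-demo-key")   # must match the key expected by the server
os.environ.setdefault("API_PORT", "8000")
client = OpenAI(
    api_key=os.environ["API_KEY"],
    base_url="http://localhost:{}/v1".format(os.environ["API_PORT"]),
)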

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import os
 import re
@@ -5,7 +19,7 @@ from setuptools import find_packages, setup
 def get_version():
-    with open(os.path.join("src", "llmtuner", "cli.py"), "r", encoding="utf-8") as f:
+    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
         file_content = f.read()
         pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
         (version,) = re.findall(pattern, file_content)
@@ -21,24 +35,25 @@ def get_requires():
     extra_require = {
         "torch": ["torch>=1.13.1"],
+        "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
         "metrics": ["nltk", "jieba", "rouge-chinese"],
-        "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
+        "deepspeed": ["deepspeed>=0.10.0"],
         "bitsandbytes": ["bitsandbytes>=0.39.0"],
-        "vllm": ["vllm>=0.4.0"],
+        "vllm": ["vllm>=0.4.3"],
         "galore": ["galore-torch"],
         "badam": ["badam"],
         "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
         "awq": ["autoawq"],
         "aqlm": ["aqlm[gpu]>=1.1.0"],
-        "qwen": ["tiktoken", "transformers_stream_generator"],
+        "qwen": ["transformers_stream_generator"],
         "modelscope": ["modelscope"],
-        "quality": ["ruff"],
+        "dev": ["ruff", "pytest"],
     }


 def main():
     setup(
-        name="llmtuner",
+        name="llamafactory",
         version=get_version(),
         author="hiyouga",
         author_email="hiyouga" "@" "buaa.edu.cn",
@@ -53,7 +68,7 @@ def main():
         python_requires=">=3.8.0",
         install_requires=get_requires(),
         extras_require=extra_require,
-        entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
+        entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
         classifiers=[
             "Development Status :: 4 - Beta",
             "Intended Audience :: Developers",

View File

@@ -1,9 +1,23 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import os

 import uvicorn

-from llmtuner.api.app import create_app
-from llmtuner.chat import ChatModel
+from llamafactory.api.app import create_app
+from llamafactory.chat import ChatModel


 def main():

View File

@@ -0,0 +1,20 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION
__version__ = VERSION

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import os
 from contextlib import asynccontextmanager
 from typing import Optional

View File

@@ -1,10 +1,27 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
+import io
 import json
+import os
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
 from ..extras.logging import get_logger
-from ..extras.packages import is_fastapi_available
+from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -25,7 +42,17 @@ if is_fastapi_available():
     from fastapi import HTTPException, status


+if is_pillow_available():
+    from PIL import Image
+
+
+if is_requests_available():
+    import requests
+
+
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..chat import ChatModel
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -40,7 +67,9 @@ ROLE_MAPPING = {
 }


-def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+def _process_request(
+    request: "ChatCompletionRequest",
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

     if len(request.messages) == 0:
@@ -49,12 +78,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
     if request.messages[0].role == Role.SYSTEM:
         system = request.messages.pop(0).content
     else:
-        system = ""
+        system = None

     if len(request.messages) % 2 == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
+    image = None
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -62,10 +92,27 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
-            name = message.tool_calls[0].function.name
-            arguments = message.tool_calls[0].function.arguments
-            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+            tool_calls = [
+                {"name": tool_call.function.name, "argument": tool_call.function.arguments}
+                for tool_call in message.tool_calls
+            ]
+            content = json.dumps(tool_calls, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+        elif isinstance(message.content, list):
+            for input_item in message.content:
+                if input_item.type == "text":
+                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                else:
+                    image_url = input_item.image_url.url
+                    if image_url.startswith("data:image"):  # base64 image
+                        image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
+                        image_path = io.BytesIO(image_data)
+                    elif os.path.isfile(image_url):  # local file
+                        image_path = open(image_url, "rb")
+                    else:  # web uri
+                        image_path = requests.get(image_url, stream=True).raw
+
+                    image = Image.open(image_path).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@@ -73,12 +120,12 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
     if isinstance(tool_list, list) and len(tool_list):
         try:
             tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
-        except Exception:
+        except json.JSONDecodeError:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
-        tools = ""
+        tools = None

-    return input_messages, system, tools
+    return input_messages, system, tools, image


 def _create_stream_chat_completion_chunk(
@@ -97,11 +144,12 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request) input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
image,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@@ -114,15 +162,17 @@ async def create_chat_completion_response(
     choices = []
     for i, response in enumerate(responses):
         if tools:
-            result = chat_model.engine.template.format_tools.extract(response.response_text)
+            result = chat_model.engine.template.extract_tool(response.response_text)
         else:
             result = response.response_text

-        if isinstance(result, tuple):
-            name, arguments = result
-            function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+        if isinstance(result, list):
+            tool_calls = []
+            for tool in result:
+                function = Function(name=tool[0], arguments=tool[1])
+                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
         else:
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
@@ -145,7 +195,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request) input_messages, system, tools, image = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -159,6 +209,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
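To illustrate the multimodal path added above, a hypothetical request body with one text part and one image_url part (the model name and URL are placeholders): _process_request() decodes base64 data URIs, opens local files, or falls back to an HTTP fetch.

payload = {
    "model": "llamafactory-demo",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
}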

View File

@@ -0,0 +1,34 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
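A quick illustration of the dictify() and jsonify() helpers defined above with a throwaway pydantic model; the AttributeError fallback is what keeps them working on both pydantic v1 and v2.

from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int = 0


print(dictify(Point(x=1)))  # {'x': 1} -- unset fields are excluded
print(jsonify(Point(x=1)))  # '{"x": 1}'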

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import time
 from enum import Enum, unique
 from typing import Any, Dict, List, Optional, Union
@@ -56,9 +70,19 @@ class FunctionCall(BaseModel):
     function: Function


+class ImageURL(BaseModel):
+    url: str
+
+
+class MultimodalInputItem(BaseModel):
+    type: Literal["text", "image_url"]
+    text: Optional[str] = None
+    image_url: Optional[ImageURL] = None
+
+
 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[str] = None
+    content: Optional[Union[str, List[MultimodalInputItem]]] = None
     tool_calls: Optional[List[FunctionCall]] = None
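An illustrative construction (not from the repository) showing how the extended ChatMessage can mix text and image parts using the models above; the URL is a placeholder.

message = ChatMessage(
    role=Role.USER,
    content=[
        MultimodalInputItem(type="text", text="Describe the image."),
        MultimodalInputItem(type="image_url", image_url=ImageURL(url="https://example.com/demo.png")),
    ],
)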

View File

@@ -0,0 +1,19 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_engine import BaseEngine
from .chat_model import ChatModel
__all__ = ["BaseEngine", "ChatModel"]

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
@@ -36,11 +50,6 @@ class BaseEngine(ABC):
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ... ) -> None: ...
@abstractmethod
async def start(
self,
) -> None: ...
@abstractmethod @abstractmethod
async def chat( async def chat(
self, self,

View File

@@ -1,3 +1,20 @@
# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import asyncio
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
@@ -14,7 +31,7 @@ if TYPE_CHECKING:
     from .base_engine import BaseEngine, Response


-def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
     asyncio.set_event_loop(loop)
     loop.run_forever()
@@ -32,7 +49,6 @@ class ChatModel:
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
-        asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)

     def chat(
         self,
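Independent of the classes above, a minimal sketch of the pattern ChatModel relies on: a private event loop running in a daemon thread that synchronous code feeds with run_coroutine_threadsafe().

import asyncio
from threading import Thread


def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


background_loop = asyncio.new_event_loop()
Thread(target=_run_loop, args=(background_loop,), daemon=True).start()


async def greet() -> str:
    return "hello"


future = asyncio.run_coroutine_threadsafe(greet(), background_loop)
print(future.result())  # "hello", computed on the background loop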

View File

@@ -1,13 +1,28 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import asyncio
 import concurrent.futures
 import os
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer

 from ..data import get_template_and_fix_tokenizer
+from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -23,6 +38,9 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


+logger = get_logger(__name__)


 class HuggingfaceEngine(BaseEngine):
     def __init__(
         self,
@@ -41,6 +59,14 @@ class HuggingfaceEngine(BaseEngine):
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
         self.generating_args = generating_args.to_dict()
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            logger.warning("There is no current event loop, creating a new one.")
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

     @staticmethod
     def _process_args(
@@ -55,47 +81,69 @@ class HuggingfaceEngine(BaseEngine):
         image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = "<image>" + messages[0]["content"]
+        if (
+            processor is not None
+            and image is not None
+            and not hasattr(processor, "image_seq_length")
+            and template.image_token not in messages[0]["content"]
+        ):  # llava-like models
+            messages[0]["content"] = template.image_token + messages[0]["content"]

         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or generating_args["default_system"]
+        pixel_values = None
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )
+        if processor is not None and image is not None:  # add image features
+            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+            batch_feature = image_processor(image, return_tensors="pt")
+            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
+            if hasattr(processor, "image_seq_length"):  # paligemma models
+                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
+        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

-        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
-        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

         if stop is not None:
-            raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
+            logger.warning("Stop parameter is not supported in Huggingface engine yet.")

         generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature if temperature is not None else generating_args["temperature"],
+                top_p=top_p if top_p is not None else generating_args["top_p"],
+                top_k=top_k if top_k is not None else generating_args["top_k"],
                 num_return_sequences=num_return_sequences,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty
+                if repetition_penalty is not None
+                else generating_args["repetition_penalty"],
+                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
         )

-        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
+        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
             generating_args["do_sample"] = True
+            generating_args["temperature"] = generating_args["temperature"] or 1.0
+
+        if not generating_args["temperature"]:
+            generating_args["do_sample"] = False

         if not generating_args["do_sample"]:
             generating_args.pop("temperature", None)
@@ -111,14 +159,13 @@ class HuggingfaceEngine(BaseEngine):
         gen_kwargs = dict(
             inputs=inputs,
+            attention_mask=attention_mask,
             generation_config=GenerationConfig(**generating_args),
             logits_processor=get_logits_processor(),
         )

-        if processor is not None and image is not None:
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-            gen_kwargs["pixel_values"] = pixel_values.to(model.device)
+        if pixel_values is not None:
+            gen_kwargs["pixel_values"] = pixel_values

         return gen_kwargs, prompt_length
@@ -220,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):
         return scores

-    async def start(self) -> None:
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -247,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)
@@ -275,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 stream = self._stream_chat(*input_args)
                 while True:
@@ -294,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):
         loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
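A rough illustration of the MAX_CONCURRENT semaphore introduced above: only that many requests run the (blocking) generation at once while the rest wait. The sleep stands in for model inference.

import asyncio
import os


async def main() -> None:
    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

    async def generate(i: int) -> None:
        async with semaphore:          # at most MAX_CONCURRENT requests run at once
            await asyncio.sleep(0.1)   # stand-in for the blocking generation call
            print("finished request", i)

    await asyncio.gather(*(generate(i) for i in range(4)))


asyncio.run(main())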

View File

@@ -1,23 +1,40 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
-from ..extras.misc import get_device_count, infer_optim_dtype
-from ..extras.packages import is_vllm_available
+from ..extras.misc import get_device_count
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
-from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
+from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response


 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest

-    from vllm.sequence import MultiModalData
+    if is_vllm_version_greater_than_0_5():
+        from vllm.multimodal.image import ImagePixelData
+    else:
+        from vllm.sequence import MultiModalData


 if TYPE_CHECKING:
-    import torch
     from numpy.typing import NDArray
     from transformers.image_processing_utils import BaseImageProcessor
@@ -36,8 +53,6 @@ class VllmEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         config = load_config(model_args)  # may download model from ms hub
-        infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-        infer_dtype = str(infer_dtype).split(".")[-1]

         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
@@ -51,7 +66,7 @@ class VllmEngine(BaseEngine):
"model": model_args.model_name_or_path, "model": model_args.model_name_or_path,
"trust_remote_code": True, "trust_remote_code": True,
"download_dir": model_args.cache_dir, "download_dir": model_args.cache_dir,
"dtype": infer_dtype, "dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen, "max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1, "tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util, "gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -59,6 +74,7 @@ class VllmEngine(BaseEngine):
"disable_log_requests": True, "disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager, "enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None, "enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
} }
if model_args.visual_inputs: if model_args.visual_inputs:
@@ -66,11 +82,10 @@ class VllmEngine(BaseEngine):
             patch_size = config.vision_config.patch_size
             self.image_feature_size = (image_size // patch_size) ** 2
             engine_args["image_input_type"] = "pixel_values"
-            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
             engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
             engine_args["image_feature_size"] = self.image_feature_size
             if getattr(config, "is_yi_vl_derived_model", None):
-                # bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
                 import vllm.model_executor.models.llava

                 logger.info("Detected Yi-VL model, applying projector patch.")
@@ -91,27 +106,52 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
+        if (
+            self.processor is not None
+            and image is not None
+            and not hasattr(self.processor, "image_seq_length")
+            and self.template.image_token not in messages[0]["content"]
+        ):  # llava-like models (TODO: paligemma models)
+            messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]

         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )

+        if self.processor is not None and image is not None:  # add image features
+            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
+            if is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
+            else:  # TODO: remove vllm 0.4.3 support
+                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+        else:
+            multi_modal_data = None
+
         prompt_length = len(prompt_ids)

-        use_beam_search = self.generating_args["num_beams"] > 1
-        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        use_beam_search: bool = self.generating_args["num_beams"] > 1
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

-        max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
+        if "max_new_tokens" in self.generating_args:
+            max_tokens = self.generating_args["max_new_tokens"]
+        elif "max_length" in self.generating_args:
+            if self.generating_args["max_length"] > prompt_length:
+                max_tokens = self.generating_args["max_length"] - prompt_length
+            else:
+                max_tokens = 1

         if max_length:
             max_tokens = max_length - prompt_length if max_length > prompt_length else 1
@@ -120,38 +160,29 @@ class VllmEngine(BaseEngine):
         sampling_params = SamplingParams(
             n=num_return_sequences,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+            repetition_penalty=(
+                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+            )
+            or 1.0,  # repetition_penalty must > 0
+            temperature=temperature if temperature is not None else self.generating_args["temperature"],
+            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
+            top_k=top_k if top_k is not None else self.generating_args["top_k"],
             use_beam_search=use_beam_search,
-            length_penalty=length_penalty,
+            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
             skip_special_tokens=True,
         )

-        if self.processor is not None and image is not None:
-            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-        else:
-            multi_modal_data = None
-
         result_generator = self.model.generate(
-            prompt=None,
+            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
             request_id=request_id,
-            prompt_token_ids=prompt_ids,
             lora_request=self.lora_request,
-            multi_modal_data=multi_modal_data,
         )
         return result_generator

-    async def start(self) -> None:
-        pass
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
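A toy walk-through (made-up numbers) of the new max_tokens resolution shown above: the engine-level defaults produce a value first, and a request-level max_length can then override it.

generating_args = {"max_new_tokens": 512}
prompt_length = 1000

if "max_new_tokens" in generating_args:
    max_tokens = generating_args["max_new_tokens"]
elif "max_length" in generating_args:
    max_tokens = max(generating_args["max_length"] - prompt_length, 1)

max_length = 1100  # hypothetical per-request override
if max_length:
    max_tokens = max_length - prompt_length if max_length > prompt_length else 1

print(max_tokens)  # 100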

View File

@@ -1,9 +1,30 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+import random
+import subprocess
 import sys
 from enum import Enum, unique

+from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras.env import VERSION, print_env
+from .extras.logging import get_logger
+from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -23,8 +44,6 @@ USAGE = (
+ "-" * 70 + "-" * 70
) )
VERSION = "0.7.1"
WELCOME = ( WELCOME = (
"-" * 58 "-" * 58
+ "\n" + "\n"
@@ -37,11 +56,14 @@ WELCOME = (
+ "-" * 58 + "-" * 58
) )
logger = get_logger(__name__)
@unique @unique
class Command(str, Enum): class Command(str, Enum):
API = "api" API = "api"
CHAT = "chat" CHAT = "chat"
ENV = "env"
EVAL = "eval" EVAL = "eval"
EXPORT = "export" EXPORT = "export"
TRAIN = "train" TRAIN = "train"
@@ -57,11 +79,34 @@ def main():
         run_api()
     elif command == Command.CHAT:
         run_chat()
+    elif command == Command.ENV:
+        print_env()
     elif command == Command.EVAL:
         run_eval()
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
+        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        if force_torchrun or get_device_count() > 1:
+            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+            subprocess.run(
+                (
+                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+                ).format(
+                    nnodes=os.environ.get("NNODES", "1"),
+                    node_rank=os.environ.get("RANK", "0"),
+                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    file_name=launcher.__file__,
+                    args=" ".join(sys.argv[1:]),
+                ),
+                shell=True,
+            )
-        run_exp()
+        else:
+            run_exp()
     elif command == Command.WEBDEMO:
         run_web_demo()
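For reference, a hedged sketch of the torchrun command string the new TRAIN branch assembles; every value below is an example default, and the launcher path and config name are placeholders rather than values taken from the repository.

command = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
    nnodes="1",
    node_rank="0",
    nproc_per_node="8",
    master_addr="127.0.0.1",
    master_port="29500",
    file_name="src/llamafactory/launcher.py",   # stand-in for launcher.__file__
    args="examples/train_lora/llama3_lora_sft.yaml",  # placeholder config
)
print(command)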

View File

@@ -0,0 +1,30 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]

View File

@@ -1,20 +1,42 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Union

 from datasets import Features

-from .utils import Role
+from ..extras.logging import get_logger
+from .data_utils import Role


 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .parser import DatasetAttr


+logger = get_logger(__name__)


 def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+    r"""
+    Optionally concatenates image path to dataset dir when loading from local disk.
+    """
     outputs = []
     if dataset_attr.load_from in ["script", "file"]:
         for image in images:
@@ -29,6 +51,9 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
 def convert_alpaca(
     examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
 ) -> Dict[str, List[Any]]:
+    r"""
+    Converts alpaca format dataset to the standard format.
+    """
     outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
     for i in range(len(examples[dataset_attr.prompt])):
@@ -45,21 +70,32 @@ def convert_alpaca(
         if dataset_attr.query and examples[dataset_attr.query][i]:
             content.append(examples[dataset_attr.query][i])

-        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
+        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})  # "prompt\nquery"

-        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
-            response = [
-                {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
-            ]
-        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
-            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
-        else:
-            response = []
+        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
+            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
+            if examples[dataset_attr.kto_tag][i]:
+                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+            else:
+                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+        elif (
+            dataset_attr.ranking
+            and isinstance(examples[dataset_attr.chosen][i], str)
+            and isinstance(examples[dataset_attr.rejected][i], str)
+        ):  # pairwise example
+            response = [
+                {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
+                {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
+            ]
+        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):  # normal example
+            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
+        else:  # unsupervised
+            response = []

         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tools"].append("")
+        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
         outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])

     return outputs
@@ -68,6 +104,9 @@ def convert_alpaca(
 def convert_sharegpt(
     examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
 ) -> Dict[str, List[Any]]:
+    r"""
+    Converts sharegpt format dataset to the standard format.
+    """
     outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
     tag_mapping = {
@@ -87,21 +126,62 @@ def convert_sharegpt(
         else:
             system = examples[dataset_attr.system][i] if dataset_attr.system else ""

-        messages = messages[: len(messages) // 2 * 2]  # should be multiples of 2
         if len(messages) == 0:
             continue

         aligned_messages = []
+        broken_data = False
         for turn_idx, message in enumerate(messages):
             if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-                raise ValueError("Invalid role tag in {}.".format(messages))
+                logger.warning("Invalid role tag in {}.".format(messages))
+                broken_data = True

             aligned_messages.append(
                 {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
             )

-        outputs["prompt"].append(aligned_messages[:-1])
-        outputs["response"].append(aligned_messages[-1:])
+        if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+            dataset_attr.ranking and len(aligned_messages) % 2 == 0
+        ):
+            logger.warning("Invalid message count in {}.".format(messages))
+            broken_data = True
+
+        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
+            prompt = aligned_messages[:-1]
+            response = aligned_messages[-1:]
+            if examples[dataset_attr.kto_tag][i]:
+                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+            else:
+                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+        elif (
+            dataset_attr.ranking
+            and isinstance(examples[dataset_attr.chosen][i], dict)
+            and isinstance(examples[dataset_attr.rejected][i], dict)
+        ):  # pairwise example
+            chosen = examples[dataset_attr.chosen][i]
+            rejected = examples[dataset_attr.rejected][i]
+            if (
+                chosen[dataset_attr.role_tag] not in accept_tags[-1]
+                or rejected[dataset_attr.role_tag] not in accept_tags[-1]
+            ):
+                logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+                broken_data = True
+
+            prompt = aligned_messages
+            response = [
+                {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
+                {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
+            ]
+        else:  # normal example
+            prompt = aligned_messages[:-1]
+            response = aligned_messages[-1:]
+
+        if broken_data:
+            logger.warning("Skipping this abnormal example.")
+            continue
+
+        outputs["prompt"].append(prompt)
+        outputs["response"].append(response)
         outputs["system"].append(system)
         outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
         outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
@@ -110,7 +190,10 @@ def convert_sharegpt(
 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -143,7 +226,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )
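To make the new branches concrete, two made-up records of the shapes the updated converters accept: an alpaca-style pairwise example and a sharegpt-style KTO example. The field names assume the default dataset_attr column mapping, which is an assumption here rather than something shown in this diff.

alpaca_pairwise_example = {
    "instruction": "Name a prime number.",
    "input": "",
    "chosen": "Seven.",
    "rejected": "Nine.",
}

sharegpt_kto_example = {
    "conversations": [
        {"from": "human", "value": "Name a prime number."},
        {"from": "gpt", "value": "Seven."},
    ],
    "kto_tag": True,
}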

View File

@@ -0,0 +1,95 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
for key in ("chosen", "rejected"):
for feature in features:
target_feature = {
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
target_feature = {
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)
return batch
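A sketch of the inputs PairwiseDataCollatorWithPadding expects and what it returns: each feature carries chosen_*/rejected_* fields, and the collated batch stacks all chosen rows first and all rejected rows second (2 * n in total). The tokenizer below is assumed to exist and is not defined in this diff.

features = [
    {
        "chosen_input_ids": [1, 2, 3],
        "chosen_attention_mask": [1, 1, 1],
        "chosen_labels": [-100, 2, 3],
        "rejected_input_ids": [1, 2],
        "rejected_attention_mask": [1, 1],
        "rejected_labels": [-100, 2],
    }
]
# collator = PairwiseDataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)
# batch = collator(features)
# batch["input_ids"].shape[0] == 2 * len(features)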

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union
@@ -10,7 +24,7 @@ if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments

-    from llmtuner.hparams import DataArguments
+    from ..hparams import DataArguments


 logger = get_logger(__name__)

Some files were not shown because too many files have changed in this diff.