Compare commits
535 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2299e261b | ||
|
|
8a44dce326 | ||
|
|
6d9233833b | ||
|
|
d019603835 | ||
|
|
478e8194d9 | ||
|
|
1890d3dafe | ||
|
|
522a3e8493 | ||
|
|
18968405d0 | ||
|
|
71a1c1321a | ||
|
|
cf58a6d860 | ||
|
|
9adc0a2c3f | ||
|
|
16419b2834 | ||
|
|
82a2bac866 | ||
|
|
151ef48b40 | ||
|
|
a255c3a476 | ||
|
|
f4ec4fa6ad | ||
|
|
2635794727 | ||
|
|
d2f845d70d | ||
|
|
bb8aba5abf | ||
|
|
9f16c50155 | ||
|
|
25bb9f5ad9 | ||
|
|
7b985f55db | ||
|
|
fd0357a26d | ||
|
|
31f9daa362 | ||
|
|
15ea576246 | ||
|
|
19a6916d80 | ||
|
|
585c475f71 | ||
|
|
e62dae37fe | ||
|
|
11672f760d | ||
|
|
b9f84900ee | ||
|
|
5f65558088 | ||
|
|
0f54a78144 | ||
|
|
2986bef530 | ||
|
|
065f7fb5da | ||
|
|
c1d5073bd3 | ||
|
|
ee46011b34 | ||
|
|
d55f420206 | ||
|
|
fcf75633a0 | ||
|
|
e77ced045d | ||
|
|
331f53381f | ||
|
|
1d675a287d | ||
|
|
be33ef67fb | ||
|
|
f5cd17881e | ||
|
|
c09b648934 | ||
|
|
f2fd9d1b25 | ||
|
|
167342af8a | ||
|
|
76f9bd1820 | ||
|
|
a893505924 | ||
|
|
ed25e051a9 | ||
|
|
5e5fc337f9 | ||
|
|
58e9ca8aa0 | ||
|
|
a4c4b8496f | ||
|
|
38c9641777 | ||
|
|
8b8fdb3a85 | ||
|
|
290057069e | ||
|
|
46203856fc | ||
|
|
80b89978d9 | ||
|
|
5a221d91f9 | ||
|
|
3a3f4072e5 | ||
|
|
0c0cdc26bc | ||
|
|
2581cc844b | ||
|
|
d58fcd094e | ||
|
|
86063e27ea | ||
|
|
88eafd865b | ||
|
|
3f7bd98bfa | ||
|
|
b72c4bd118 | ||
|
|
808ff89a2d | ||
|
|
6d7f1299bd | ||
|
|
0420a608ca | ||
|
|
2047eab723 | ||
|
|
e11b40c344 | ||
|
|
b869506a57 | ||
|
|
72d5b06b08 | ||
|
|
94726bdc8d | ||
|
|
4d1791e905 | ||
|
|
528e06ccaa | ||
|
|
fec641ec82 | ||
|
|
8f401e37f8 | ||
|
|
9feb78e7b4 | ||
|
|
c2022431aa | ||
|
|
0817c24c04 | ||
|
|
cfb926fb84 | ||
|
|
34746d6151 | ||
|
|
5bb447b118 | ||
|
|
a28261a866 | ||
|
|
800de98dc8 | ||
|
|
222423bcef | ||
|
|
e71737351f | ||
|
|
4f298894da | ||
|
|
a8fae3869d | ||
|
|
db9b977e4f | ||
|
|
87d685b59f | ||
|
|
e4046bdd1f | ||
|
|
5baa3add8c | ||
|
|
332f637592 | ||
|
|
31daa6570b | ||
|
|
33525a34b6 | ||
|
|
3607caa2ad | ||
|
|
0fc2e19279 | ||
|
|
ef994600db | ||
|
|
7638f1070e | ||
|
|
c2120432db | ||
|
|
66184762e8 | ||
|
|
41a9e231cb | ||
|
|
1bb06e06df | ||
|
|
381f7120e6 | ||
|
|
f7857c83e1 | ||
|
|
d0da6f40b0 | ||
|
|
28d145a066 | ||
|
|
ae32c148d1 | ||
|
|
2a05941b14 | ||
|
|
11c38b9173 | ||
|
|
73c1c15b62 | ||
|
|
7f58bf984f | ||
|
|
ec552372ba | ||
|
|
17d32fb5c7 | ||
|
|
4b61610b12 | ||
|
|
07798e4aad | ||
|
|
6d6acd0213 | ||
|
|
a789e0f263 | ||
|
|
f9ee00b6b6 | ||
|
|
31bfdb08cd | ||
|
|
12c83e00fc | ||
|
|
9dc7b6c7ac | ||
|
|
627548bf7f | ||
|
|
dc65ecdf09 | ||
|
|
e577990eb2 | ||
|
|
1f3b729a4b | ||
|
|
0aa7ac210f | ||
|
|
40382f1387 | ||
|
|
75b3819e43 | ||
|
|
e63c2df0b1 | ||
|
|
25d4889789 | ||
|
|
8c0a721c4c | ||
|
|
9e972bc9ec | ||
|
|
1675712a4c | ||
|
|
e0c9012f7f | ||
|
|
a25024bd0c | ||
|
|
867980196e | ||
|
|
4e25d037c8 | ||
|
|
6ba6926221 | ||
|
|
b6b53b61f7 | ||
|
|
647c51a772 | ||
|
|
3b843ac9d4 | ||
|
|
0ef1f981da | ||
|
|
944a2aec4d | ||
|
|
4f31ad997c | ||
|
|
8683582300 | ||
|
|
5ccc607222 | ||
|
|
d8bd46f1bf | ||
|
|
8c2a712247 | ||
|
|
53e41bf2c7 | ||
|
|
0eeae9061c | ||
|
|
08729dbefc | ||
|
|
2c120aa0df | ||
|
|
cca6286b6f | ||
|
|
8516054e4d | ||
|
|
d1a8cd67d2 | ||
|
|
8a5b4bdfd4 | ||
|
|
3bceef02ee | ||
|
|
166a830938 | ||
|
|
18767fe026 | ||
|
|
18a1a4b9da | ||
|
|
6015fe700e | ||
|
|
369dae8dd3 | ||
|
|
2aaf3697d7 | ||
|
|
5504b5254c | ||
|
|
b2e4f11602 | ||
|
|
e3f95abca7 | ||
|
|
2f44f70c2c | ||
|
|
f8f05a883b | ||
|
|
5f473e2696 | ||
|
|
88b1874c04 | ||
|
|
58bc6943dc | ||
|
|
2dedf7b401 | ||
|
|
5769a553d2 | ||
|
|
552816e04b | ||
|
|
b5fa1044b8 | ||
|
|
3c55976a0e | ||
|
|
4611f67fae | ||
|
|
a5346041bb | ||
|
|
df42e438c1 | ||
|
|
7dbfd7dff6 | ||
|
|
a897d46049 | ||
|
|
adff887659 | ||
|
|
eba78f2159 | ||
|
|
ec05c8cdb4 | ||
|
|
0a869c4ed4 | ||
|
|
f792eaf8d4 | ||
|
|
8a41c96761 | ||
|
|
e5d9d8c55d | ||
|
|
3e44c8fe3a | ||
|
|
925e421bde | ||
|
|
bbb636bdba | ||
|
|
a30bdbb1c0 | ||
|
|
95b7e10a06 | ||
|
|
0385c60177 | ||
|
|
44895ebe36 | ||
|
|
44dfbf9dbd | ||
|
|
0a465fc3ca | ||
|
|
01eeae50b5 | ||
|
|
7eeeffdb8a | ||
|
|
eca06531c3 | ||
|
|
d90b40b60f | ||
|
|
1898c1e9a6 | ||
|
|
8d2f8b0dd8 | ||
|
|
df42281256 | ||
|
|
896cf476d5 | ||
|
|
37961d5f06 | ||
|
|
bb047bc844 | ||
|
|
448adedf6a | ||
|
|
469c7cd462 | ||
|
|
ebf6a07681 | ||
|
|
53f0fff513 | ||
|
|
ab7567693d | ||
|
|
1b8aab0723 | ||
|
|
30ebe61914 | ||
|
|
6f1c8dacea | ||
|
|
8881237475 | ||
|
|
584755be4b | ||
|
|
3d3324be5c | ||
|
|
4196d5b4d6 | ||
|
|
101c95ce65 | ||
|
|
19ebc0e7a2 | ||
|
|
1ce15b5d9e | ||
|
|
d670d62a66 | ||
|
|
6522467ddb | ||
|
|
aacd9642f5 | ||
|
|
4446c92517 | ||
|
|
8c65548b10 | ||
|
|
fb22651faf | ||
|
|
cfff136b2a | ||
|
|
bac2c64f87 | ||
|
|
be1ec97c8e | ||
|
|
bbd432415d | ||
|
|
1fef702382 | ||
|
|
39865d8a1f | ||
|
|
c7b27bd70b | ||
|
|
86e4fab0d5 | ||
|
|
ff3e40e4a5 | ||
|
|
ea830cad0c | ||
|
|
225e270fd5 | ||
|
|
c1768cfb14 | ||
|
|
53edd62f8b | ||
|
|
41a7e128b6 | ||
|
|
6b8c41c3ac | ||
|
|
2f09c34980 | ||
|
|
76dc69ce36 | ||
|
|
6c9d05539a | ||
|
|
b6bc17f730 | ||
|
|
c07ba8ccc0 | ||
|
|
ed86f621a0 | ||
|
|
c6a3175bbf | ||
|
|
452291417d | ||
|
|
ab9db8b7c7 | ||
|
|
877e2ea791 | ||
|
|
6ea42d5b63 | ||
|
|
31c117e696 | ||
|
|
04f057334f | ||
|
|
99a54d06ca | ||
|
|
8332c85f37 | ||
|
|
fcf1a3df62 | ||
|
|
f4f52ae67d | ||
|
|
0b08d5882a | ||
|
|
62eeafaba6 | ||
|
|
5a52e41399 | ||
|
|
e8083f8f3f | ||
|
|
338b3a03f0 | ||
|
|
c8b01b41ac | ||
|
|
6d08a418ed | ||
|
|
e3066d1489 | ||
|
|
487e3f2507 | ||
|
|
b82a53cad8 | ||
|
|
5bec82ca9d | ||
|
|
57354fc990 | ||
|
|
89f240805c | ||
|
|
27bbea886c | ||
|
|
3ec3dda33a | ||
|
|
ae9f338bf7 | ||
|
|
bf44f76dc7 | ||
|
|
c18581f0a4 | ||
|
|
9f6c5c4798 | ||
|
|
7bc03ac986 | ||
|
|
85d7e4f4ab | ||
|
|
bf69747f40 | ||
|
|
f1146bf7b6 | ||
|
|
9efd1fec90 | ||
|
|
3b91839a55 | ||
|
|
bc4421eeef | ||
|
|
5003820a6a | ||
|
|
cd2485f28d | ||
|
|
918a367378 | ||
|
|
3d35aeca72 | ||
|
|
53b1e5fd1d | ||
|
|
b852c895cf | ||
|
|
aaa7ed8712 | ||
|
|
205aca5b03 | ||
|
|
87b1f851f1 | ||
|
|
fca814b30d | ||
|
|
a20c2b6ecf | ||
|
|
fee94e1c54 | ||
|
|
047a596542 | ||
|
|
3d45606984 | ||
|
|
310c107d56 | ||
|
|
089e4d9e96 | ||
|
|
ae56c3cf49 | ||
|
|
0a0288a286 | ||
|
|
25da686758 | ||
|
|
e2da3cc9fa | ||
|
|
c42e5cf401 | ||
|
|
9943cd1c96 | ||
|
|
1e6f96508a | ||
|
|
d401974f69 | ||
|
|
09b2dbe859 | ||
|
|
7f8ef8c132 | ||
|
|
fcb6283a72 | ||
|
|
0027f46ccc | ||
|
|
967a27695e | ||
|
|
3ce8a326c6 | ||
|
|
91b56b7baf | ||
|
|
e2fa961302 | ||
|
|
87d6d7dc61 | ||
|
|
00019e2ca4 | ||
|
|
b104739d63 | ||
|
|
6ef0d13e42 | ||
|
|
b238d1aa04 | ||
|
|
aa497d5d96 | ||
|
|
fecf04b2f4 | ||
|
|
3f157e2f6f | ||
|
|
c7c558562e | ||
|
|
c2ea5fb618 | ||
|
|
fa9c32bb8d | ||
|
|
c610deb5a2 | ||
|
|
2bb3255e74 | ||
|
|
b28b74c71e | ||
|
|
1ed921bff7 | ||
|
|
80f634cc95 | ||
|
|
a3eb5e200c | ||
|
|
2d02c0e22d | ||
|
|
093eda2ad6 | ||
|
|
dbaf621f57 | ||
|
|
ceb701c2d4 | ||
|
|
29ad3783f5 | ||
|
|
fa2386e73c | ||
|
|
e0045e8386 | ||
|
|
b94c941196 | ||
|
|
ba66ac084f | ||
|
|
83479c9ef0 | ||
|
|
df8ac15ef0 | ||
|
|
8cea5cd967 | ||
|
|
a2d7d6a518 | ||
|
|
a63e624eca | ||
|
|
8596c321ce | ||
|
|
54cd799aa0 | ||
|
|
8185eb1890 | ||
|
|
03213984ec | ||
|
|
aeeee9d4b5 | ||
|
|
c8a1fb99bf | ||
|
|
f0181a41ff | ||
|
|
f6b06d0c6f | ||
|
|
1047217f78 | ||
|
|
16a9a44849 | ||
|
|
58fb24ce41 | ||
|
|
a9afffa246 | ||
|
|
1fdd053022 | ||
|
|
0a833968a0 | ||
|
|
58b681de78 | ||
|
|
22d5fc5f4c | ||
|
|
cc0119f698 | ||
|
|
580cedebde | ||
|
|
43bd1b070c | ||
|
|
42aa9c65be | ||
|
|
b0b87fa33f | ||
|
|
22912eba1a | ||
|
|
e2748fa967 | ||
|
|
248d5daaff | ||
|
|
8f5921692e | ||
|
|
e880eb8844 | ||
|
|
dc076c4e52 | ||
|
|
8306e93ef3 | ||
|
|
6a2cd129c0 | ||
|
|
30d7f6a22e | ||
|
|
5440ebbae6 | ||
|
|
22dbe694e9 | ||
|
|
64ac6ca396 | ||
|
|
377d37fa7f | ||
|
|
55296744a8 | ||
|
|
d0889012c2 | ||
|
|
3a8b2890eb | ||
|
|
5b2284a51d | ||
|
|
4807d8a4ef | ||
|
|
c6e1313977 | ||
|
|
66819fd3ee | ||
|
|
bd85e370be | ||
|
|
cc097174cc | ||
|
|
7d135bbdb8 | ||
|
|
4845a76535 | ||
|
|
67645c0db8 | ||
|
|
f463b3f038 | ||
|
|
01defc2779 | ||
|
|
c9e77ab352 | ||
|
|
c3de160d1c | ||
|
|
3693d7b571 | ||
|
|
a63144c28f | ||
|
|
2b3b0473cd | ||
|
|
9d929897ce | ||
|
|
313a5e1494 | ||
|
|
74dd25224a | ||
|
|
c7efc7f2ed | ||
|
|
c71c78da50 | ||
|
|
f4897da009 | ||
|
|
a6951db970 | ||
|
|
9d27aaa38f | ||
|
|
3b19b6f31b | ||
|
|
5b15ca0b0b | ||
|
|
aad79127e6 | ||
|
|
c42dcab32b | ||
|
|
be519c84d9 | ||
|
|
b2dc6dc59a | ||
|
|
9df626dc18 | ||
|
|
8d4b9200a1 | ||
|
|
7806df46ba | ||
|
|
bba026a212 | ||
|
|
6e111eb29f | ||
|
|
2b69ae0eb2 | ||
|
|
13d73574ef | ||
|
|
bc264807ae | ||
|
|
f9815dd20a | ||
|
|
1f58943b32 | ||
|
|
6476507429 | ||
|
|
35862d19ec | ||
|
|
1272cb00df | ||
|
|
e9ac26db4c | ||
|
|
20ee1d2e19 | ||
|
|
cbc1dd0c88 | ||
|
|
870bbabbc4 | ||
|
|
8fd84c375e | ||
|
|
32b5364051 | ||
|
|
cf72aec098 | ||
|
|
87849d12d2 | ||
|
|
a19512436f | ||
|
|
6c89d93aea | ||
|
|
345f40a660 | ||
|
|
8b9a814653 | ||
|
|
05fabf9095 | ||
|
|
95eede911a | ||
|
|
7bc7f7d673 | ||
|
|
054fdbe186 | ||
|
|
f0f80819a0 | ||
|
|
e702678252 | ||
|
|
553579986a | ||
|
|
622cb04f27 | ||
|
|
f3ba11a432 | ||
|
|
8b1f53bca5 | ||
|
|
ac25fef80e | ||
|
|
15f819d273 | ||
|
|
f2d1c43d28 | ||
|
|
464acc7d6c | ||
|
|
a96c5da737 | ||
|
|
28d09b81c9 | ||
|
|
a769d0e3d4 | ||
|
|
1b98b5e65c | ||
|
|
3cc5408da7 | ||
|
|
689f5c4554 | ||
|
|
ab5d042cd3 | ||
|
|
4d43317aa1 | ||
|
|
ed3b0c5b40 | ||
|
|
67a97794ee | ||
|
|
2c7c93cb9b | ||
|
|
4d4fe08d14 | ||
|
|
85a919b6f7 | ||
|
|
fe2abe20fc | ||
|
|
12444720db | ||
|
|
510faf5805 | ||
|
|
722e01c8ab | ||
|
|
6050e6cff9 | ||
|
|
c8abbe4fc3 | ||
|
|
f2881c9d4a | ||
|
|
1ded3abdf1 | ||
|
|
e641f1215a | ||
|
|
ca736bcab7 | ||
|
|
bddb2646bd | ||
|
|
e4c57f54f8 | ||
|
|
6de82ca843 | ||
|
|
b2c02df555 | ||
|
|
ca86d6361e | ||
|
|
b6fb00e046 | ||
|
|
86c84972c8 | ||
|
|
9390927875 | ||
|
|
c4a585f232 | ||
|
|
300feb3245 | ||
|
|
cacafb0038 | ||
|
|
6509114259 | ||
|
|
7d4cb79822 | ||
|
|
b867e164fe | ||
|
|
26bbfc084d | ||
|
|
c376eed31d | ||
|
|
7c595abc38 | ||
|
|
c428ab68d8 | ||
|
|
968b9f1852 | ||
|
|
018266c66e | ||
|
|
111c644bf1 | ||
|
|
ed5c641e8b | ||
|
|
de72d1f0e7 | ||
|
|
8bfb856923 | ||
|
|
8fdbaab95d | ||
|
|
a01668bbe8 | ||
|
|
3385616a37 | ||
|
|
1f0d89328d | ||
|
|
a7feab45d5 | ||
|
|
f34322afd7 | ||
|
|
3815fa40b7 | ||
|
|
c43050b3fa | ||
|
|
3e152872ad | ||
|
|
ae6ad55758 | ||
|
|
0118a2fc04 | ||
|
|
4dd81976f4 | ||
|
|
2b4da8baf6 | ||
|
|
7d1b4071e8 | ||
|
|
8fc5377f50 | ||
|
|
e5812f261d | ||
|
|
f7e85cd7de | ||
|
|
749395420b | ||
|
|
7d536d1d75 | ||
|
|
7fd0d2fc2f | ||
|
|
ec696bbcdd | ||
|
|
df24345d65 | ||
|
|
386dd26097 | ||
|
|
514f976cc1 | ||
|
|
66b870fd08 | ||
|
|
24d3c7e378 | ||
|
|
484128b641 | ||
|
|
588ea95732 | ||
|
|
800567cde7 | ||
|
|
7a3ba5a25d |
@@ -7,6 +7,8 @@ data
|
|||||||
docker
|
docker
|
||||||
saves
|
saves
|
||||||
hf_cache
|
hf_cache
|
||||||
|
ms_cache
|
||||||
|
om_cache
|
||||||
output
|
output
|
||||||
.dockerignore
|
.dockerignore
|
||||||
.gitattributes
|
.gitattributes
|
||||||
|
|||||||
23
.env.local
23
.env.local
@@ -1,35 +1,40 @@
|
|||||||
# Note: actually we do not support .env, just for reference
|
# Note: actually we do not support .env, just for reference
|
||||||
# api
|
# api
|
||||||
API_HOST=0.0.0.0
|
API_HOST=
|
||||||
API_PORT=8000
|
API_PORT=
|
||||||
API_KEY=
|
API_KEY=
|
||||||
API_MODEL_NAME=gpt-3.5-turbo
|
API_MODEL_NAME=
|
||||||
|
API_VERBOSE=
|
||||||
FASTAPI_ROOT_PATH=
|
FASTAPI_ROOT_PATH=
|
||||||
|
MAX_CONCURRENT=
|
||||||
# general
|
# general
|
||||||
DISABLE_VERSION_CHECK=
|
DISABLE_VERSION_CHECK=
|
||||||
FORCE_CHECK_IMPORTS=
|
FORCE_CHECK_IMPORTS=
|
||||||
FORCE_TORCHRUN=
|
ALLOW_EXTRA_ARGS=
|
||||||
LLAMAFACTORY_VERBOSITY=
|
LLAMAFACTORY_VERBOSITY=
|
||||||
USE_MODELSCOPE_HUB=
|
USE_MODELSCOPE_HUB=
|
||||||
|
USE_OPENMIND_HUB=
|
||||||
|
USE_RAY=
|
||||||
RECORD_VRAM=
|
RECORD_VRAM=
|
||||||
# torchrun
|
# torchrun
|
||||||
FORCE_TORCHRUN=
|
FORCE_TORCHRUN=
|
||||||
MASTER_ADDR=
|
MASTER_ADDR=
|
||||||
MASTER_PORT=
|
MASTER_PORT=
|
||||||
NNODES=
|
NNODES=
|
||||||
RANK=
|
NODE_RANK=
|
||||||
NPROC_PER_NODE=
|
NPROC_PER_NODE=
|
||||||
# wandb
|
# wandb
|
||||||
WANDB_DISABLED=
|
WANDB_DISABLED=
|
||||||
WANDB_PROJECT=huggingface
|
WANDB_PROJECT=
|
||||||
WANDB_API_KEY=
|
WANDB_API_KEY=
|
||||||
# gradio ui
|
# gradio ui
|
||||||
GRADIO_SHARE=False
|
GRADIO_SHARE=
|
||||||
GRADIO_SERVER_NAME=0.0.0.0
|
GRADIO_SERVER_NAME=
|
||||||
GRADIO_SERVER_PORT=
|
GRADIO_SERVER_PORT=
|
||||||
GRADIO_ROOT_PATH=
|
GRADIO_ROOT_PATH=
|
||||||
|
GRADIO_IPV6=
|
||||||
# setup
|
# setup
|
||||||
ENABLE_SHORT_CONSOLE=1
|
ENABLE_SHORT_CONSOLE=
|
||||||
# reserved (do not use)
|
# reserved (do not use)
|
||||||
LLAMABOARD_ENABLED=
|
LLAMABOARD_ENABLED=
|
||||||
LLAMABOARD_WORKDIR=
|
LLAMABOARD_WORKDIR=
|
||||||
|
|||||||
46
.github/CONTRIBUTING.md
vendored
46
.github/CONTRIBUTING.md
vendored
@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
|
|||||||
### Style guide
|
### Style guide
|
||||||
|
|
||||||
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
|
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
|
||||||
|
|
||||||
|
### Create a Pull Request
|
||||||
|
|
||||||
|
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
|
||||||
|
|
||||||
|
2. Clone your fork to your local disk, and add the base repository as a remote:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone git@github.com:[username]/LLaMA-Factory.git
|
||||||
|
cd LLaMA-Factory
|
||||||
|
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Create a new branch to hold your development changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git checkout -b dev_your_branch
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Set up a development environment by running the following command in a virtual environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
|
||||||
|
|
||||||
|
5. Check code before commit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make commit
|
||||||
|
make style && make quality
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Submit changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add .
|
||||||
|
git commit -m "commit message"
|
||||||
|
git fetch upstream
|
||||||
|
git rebase upstream/main
|
||||||
|
git push -u origin dev_your_branch
|
||||||
|
```
|
||||||
|
|
||||||
|
7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
|
||||||
|
|||||||
63
.github/ISSUE_TEMPLATE/1-bug-report.yml
vendored
Normal file
63
.github/ISSUE_TEMPLATE/1-bug-report.yml
vendored
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
name: "\U0001F41B Bug / help"
|
||||||
|
description: Create a report to help us improve the LLaMA Factory
|
||||||
|
labels: ["bug", "pending"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response.
|
||||||
|
已经包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead.
|
||||||
|
请勿在此分类下创建和框架 bug 无关的 issues,请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: reminder
|
||||||
|
attributes:
|
||||||
|
label: Reminder
|
||||||
|
description: |
|
||||||
|
Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).
|
||||||
|
请确保您已经认真阅读了上述规则并且搜索过现有的 issues(包括常见问题)。
|
||||||
|
|
||||||
|
options:
|
||||||
|
- label: I have read the above rules and searched the existing issues.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: system-info
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
attributes:
|
||||||
|
label: System Info
|
||||||
|
description: |
|
||||||
|
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
|
||||||
|
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
|
||||||
|
|
||||||
|
placeholder: llamafactory version, platform, python version, ...
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduction
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
attributes:
|
||||||
|
label: Reproduction
|
||||||
|
description: |
|
||||||
|
Please provide entry arguments, error messages and stack traces that reproduces the problem.
|
||||||
|
请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
|
||||||
|
Remember to wrap your log messages with \`\`\`.
|
||||||
|
请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。
|
||||||
|
|
||||||
|
value: |
|
||||||
|
```text
|
||||||
|
Put your message here.
|
||||||
|
```
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: others
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
attributes:
|
||||||
|
label: Others
|
||||||
41
.github/ISSUE_TEMPLATE/2-feature-request.yml
vendored
Normal file
41
.github/ISSUE_TEMPLATE/2-feature-request.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
name: "\U0001F680 Feature request"
|
||||||
|
description: Submit a request for a new feature
|
||||||
|
labels: ["enhancement", "pending"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Please do not create issues that are not related to new features under this category.
|
||||||
|
请勿在此分类下创建和新特性无关的 issues。
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: reminder
|
||||||
|
attributes:
|
||||||
|
label: Reminder
|
||||||
|
description: |
|
||||||
|
Please ensure you have read the above rules carefully and searched the existing issues.
|
||||||
|
请确保您已经认真阅读了上述规则并且搜索过现有的 issues。
|
||||||
|
|
||||||
|
options:
|
||||||
|
- label: I have read the above rules and searched the existing issues.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
attributes:
|
||||||
|
label: Description
|
||||||
|
description: |
|
||||||
|
A clear and concise description of the feature proposal.
|
||||||
|
请详细描述您希望加入的新功能特性。
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: contribution
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
attributes:
|
||||||
|
label: Pull Request
|
||||||
|
description: |
|
||||||
|
Have you already created the relevant PR and submitted the code?
|
||||||
|
您是否已经创建了相关 PR 并提交了代码?
|
||||||
66
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
66
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,66 +0,0 @@
|
|||||||
name: "\U0001F41B Bug / Help"
|
|
||||||
description: Create a report to help us improve the LLaMA Factory
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: |
|
|
||||||
Issues included in **FAQs** or those with **insufficient** information may be closed without a response.
|
|
||||||
包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。
|
|
||||||
|
|
||||||
- type: checkboxes
|
|
||||||
id: reminder
|
|
||||||
attributes:
|
|
||||||
label: Reminder
|
|
||||||
description: |
|
|
||||||
Please ensure you have read the README carefully and searched the existing issues (including FAQs).
|
|
||||||
请确保您已经认真阅读了 README 并且搜索过现有的 issues(包括常见问题)。
|
|
||||||
|
|
||||||
options:
|
|
||||||
- label: I have read the README and searched the existing issues.
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: system-info
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
attributes:
|
|
||||||
label: System Info
|
|
||||||
description: |
|
|
||||||
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
|
|
||||||
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
|
|
||||||
|
|
||||||
placeholder: llamafactory version, platform, python version, ...
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: reproduction
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
attributes:
|
|
||||||
label: Reproduction
|
|
||||||
description: |
|
|
||||||
Please provide code snippets, error messages and stack traces that reproduces the problem.
|
|
||||||
请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。
|
|
||||||
Remember to use Markdown tags to correctly format your code.
|
|
||||||
请合理使用 Markdown 标签来格式化您的文本。
|
|
||||||
|
|
||||||
placeholder: |
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train ...
|
|
||||||
```
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: expected-behavior
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
attributes:
|
|
||||||
label: Expected behavior
|
|
||||||
description: |
|
|
||||||
Please provide a clear and concise description of what you would expect to happen.
|
|
||||||
请提供您原本的目的,即这段代码的期望行为。
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: others
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
attributes:
|
|
||||||
label: Others
|
|
||||||
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
blank_issues_enabled: false
|
||||||
8
.github/workflows/label_issue.yml
vendored
8
.github/workflows/label_issue.yml
vendored
@@ -18,13 +18,15 @@ jobs:
|
|||||||
ISSUE_URL: ${{ github.event.issue.html_url }}
|
ISSUE_URL: ${{ github.event.issue.html_url }}
|
||||||
ISSUE_TITLE: ${{ github.event.issue.title }}
|
ISSUE_TITLE: ${{ github.event.issue.title }}
|
||||||
run: |
|
run: |
|
||||||
LABEL=pending
|
LABEL=""
|
||||||
NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
|
NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
|
||||||
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
|
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
|
||||||
for KEYWORD in ${NPU_KEYWORDS[@]}; do
|
for KEYWORD in ${NPU_KEYWORDS[@]}; do
|
||||||
if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
|
if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
|
||||||
LABEL=pending,npu
|
LABEL="npu"
|
||||||
break
|
break
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
gh issue edit $ISSUE_URL --add-label $LABEL
|
if [ -n "$LABEL" ]; then
|
||||||
|
gh issue edit $ISSUE_URL --add-label $LABEL
|
||||||
|
fi
|
||||||
|
|||||||
2
.github/workflows/publish.yml
vendored
2
.github/workflows/publish.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.8"
|
python-version: "3.9"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -22,10 +22,10 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version:
|
python-version:
|
||||||
- "3.8"
|
|
||||||
- "3.9"
|
- "3.9"
|
||||||
- "3.10"
|
- "3.10"
|
||||||
- "3.11"
|
- "3.11"
|
||||||
|
- "3.12"
|
||||||
os:
|
os:
|
||||||
- "ubuntu-latest"
|
- "ubuntu-latest"
|
||||||
- "windows-latest"
|
- "windows-latest"
|
||||||
@@ -33,9 +33,6 @@ jobs:
|
|||||||
|
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
environment:
|
|
||||||
name: tests
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
OS_NAME: ${{ matrix.os }}
|
OS_NAME: ${{ matrix.os }}
|
||||||
@@ -54,7 +51,6 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install git+https://github.com/huggingface/transformers.git
|
|
||||||
python -m pip install ".[torch,dev]"
|
python -m pip install ".[torch,dev]"
|
||||||
|
|
||||||
- name: Check quality
|
- name: Check quality
|
||||||
|
|||||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -159,11 +159,20 @@ cython_debug/
|
|||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
|
# vscode
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
# uv
|
||||||
|
uv.lock
|
||||||
|
|
||||||
# custom .gitignore
|
# custom .gitignore
|
||||||
ms_cache/
|
ms_cache/
|
||||||
hf_cache/
|
hf_cache/
|
||||||
|
om_cache/
|
||||||
cache/
|
cache/
|
||||||
config/
|
config/
|
||||||
saves/
|
saves/
|
||||||
output/
|
output/
|
||||||
wandb/
|
wandb/
|
||||||
|
swanlog/
|
||||||
|
generated_predictions.jsonl
|
||||||
|
|||||||
28
.pre-commit-config.yaml
Normal file
28
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-ast
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ['--maxkb=25000']
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: check-yaml
|
||||||
|
- id: debug-statements
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
args: [--markdown-linebreak-ext=md]
|
||||||
|
- id: no-commit-to-branch
|
||||||
|
args: ['--branch', 'main']
|
||||||
|
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v3.17.0
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
args: [--py38-plus]
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.6.9
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
- id: ruff-format
|
||||||
11
Makefile
11
Makefile
@@ -1,7 +1,14 @@
|
|||||||
.PHONY: quality style test
|
.PHONY: build commit quality style test
|
||||||
|
|
||||||
check_dirs := scripts src tests setup.py
|
check_dirs := scripts src tests setup.py
|
||||||
|
|
||||||
|
build:
|
||||||
|
pip install build && python -m build
|
||||||
|
|
||||||
|
commit:
|
||||||
|
pre-commit install
|
||||||
|
pre-commit run --all-files
|
||||||
|
|
||||||
quality:
|
quality:
|
||||||
ruff check $(check_dirs)
|
ruff check $(check_dirs)
|
||||||
ruff format --check $(check_dirs)
|
ruff format --check $(check_dirs)
|
||||||
@@ -11,4 +18,4 @@ style:
|
|||||||
ruff format $(check_dirs)
|
ruff format $(check_dirs)
|
||||||
|
|
||||||
test:
|
test:
|
||||||
CUDA_VISIBLE_DEVICES= pytest tests/
|
CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
|
||||||
|
|||||||
325
README.md
325
README.md
@@ -1,19 +1,31 @@
|
|||||||

|

|
||||||
|
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||||
[](LICENSE)
|
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
|
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
||||||
|
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
||||||
[](https://pypi.org/project/llamafactory/)
|
[](https://pypi.org/project/llamafactory/)
|
||||||
[](#projects-using-llama-factory)
|
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
|
||||||
[](https://twitter.com/llamafactory_ai)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
|
[](https://gitcode.com/zhengyaowei/LLaMA-Factory)
|
||||||
|
|
||||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
[](https://trendshift.io/repositories/4535)
|
<h3 align="center">
|
||||||
|
Easily fine-tune 100+ large language models with zero-code <a href="#quickstart">CLI</a> and <a href="#fine-tuning-with-llama-board-gui-powered-by-gradio">Web UI</a>
|
||||||
|
</h3>
|
||||||
|
<p align="center">
|
||||||
|
<picture>
|
||||||
|
<img alt="Github trend" src="https://trendshift.io/api/badge/repositories/4535">
|
||||||
|
</picture>
|
||||||
|
</p>
|
||||||
|
|
||||||
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
|
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
|
||||||
|
|
||||||
@@ -25,10 +37,14 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
|
|||||||
|
|
||||||
Choose your path:
|
Choose your path:
|
||||||
|
|
||||||
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
- **Documentation**: https://llamafactory.readthedocs.io/en/latest/
|
||||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||||
- **Local machine**: Please refer to [usage](#getting-started)
|
- **Local machine**: Please refer to [usage](#getting-started)
|
||||||
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
|
- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
|
||||||
|
- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
@@ -40,6 +56,16 @@ Choose your path:
|
|||||||
- [Provided Datasets](#provided-datasets)
|
- [Provided Datasets](#provided-datasets)
|
||||||
- [Requirement](#requirement)
|
- [Requirement](#requirement)
|
||||||
- [Getting Started](#getting-started)
|
- [Getting Started](#getting-started)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [Data Preparation](#data-preparation)
|
||||||
|
- [Quickstart](#quickstart)
|
||||||
|
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
||||||
|
- [Build Docker](#build-docker)
|
||||||
|
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
|
||||||
|
- [Download from ModelScope Hub](#download-from-modelscope-hub)
|
||||||
|
- [Download from Modelers Hub](#download-from-modelers-hub)
|
||||||
|
- [Use W&B Logger](#use-wb-logger)
|
||||||
|
- [Use SwanLab Logger](#use-swanlab-logger)
|
||||||
- [Projects using LLaMA Factory](#projects-using-llama-factory)
|
- [Projects using LLaMA Factory](#projects-using-llama-factory)
|
||||||
- [License](#license)
|
- [License](#license)
|
||||||
- [Citation](#citation)
|
- [Citation](#citation)
|
||||||
@@ -47,14 +73,22 @@ Choose your path:
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
|
||||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
|
||||||
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
|
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
|
||||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
- **Wide tasks**: Multi-turn dialogue, tool use, image understanding, visual grounding, video recognition, audio understanding, etc.
|
||||||
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
|
||||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
|
|
||||||
|
### Day-N Support for Fine-Tuning Cutting-Edge Models
|
||||||
|
|
||||||
|
| Support Date | Model Name |
|
||||||
|
| ------------ | ---------------------------------------------------------- |
|
||||||
|
| Day 0        | Qwen2.5 / Qwen2-VL / QwQ / QVQ / InternLM3 / MiniCPM-o-2.6 |
|
||||||
|
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
|
||||||
|
|
||||||
## Benchmark
|
## Benchmark
|
||||||
|
|
||||||
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
|
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
|
||||||
@@ -72,17 +106,41 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for GRPO training.
|
||||||
|
|
||||||
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
|
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
|
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** models on audio understanding tasks.
|
||||||
|
|
||||||
|
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.
|
||||||
|
|
||||||
<details><summary>Full Changelog</summary>
|
<details><summary>Full Changelog</summary>
|
||||||
|
|
||||||
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
|
||||||
|
|
||||||
|
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
|
||||||
|
|
||||||
|
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
|
||||||
|
|
||||||
|
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
|
||||||
|
|
||||||
|
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
|
||||||
|
|
||||||
|
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
|
||||||
|
|
||||||
|
[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
|
||||||
|
|
||||||
|
[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
||||||
|
|
||||||
|
[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
|
||||||
|
|
||||||
|
[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
|
||||||
|
|
||||||
|
[24/07/04] We supported [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
||||||
|
|
||||||
|
[24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
||||||
|
|
||||||
@@ -128,7 +186,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
||||||
|
|
||||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
|
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
|
||||||
|
|
||||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
||||||
|
|
||||||
@@ -160,34 +218,51 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
| Model | Model size | Template |
|
| Model | Model size | Template |
|
||||||
| ----------------------------------------------------------------- | -------------------------------- | --------- |
|
| ----------------------------------------------------------------- | -------------------------------- | ------------------- |
|
||||||
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
||||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
|
||||||
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
||||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
||||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
|
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
||||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
|
||||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
|
||||||
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
|
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||||
|
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||||
|
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||||
|
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||||
|
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
||||||
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
|
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
|
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||||
|
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
|
||||||
|
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
||||||
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
|
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
|
||||||
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||||
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||||
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||||
@@ -271,9 +346,13 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||||
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||||
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
|
||||||
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
||||||
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
||||||
|
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
||||||
|
- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
|
||||||
|
- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
|
||||||
|
- [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
|
||||||
|
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
|
||||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||||
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
||||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
@@ -313,35 +392,34 @@ huggingface-cli login
|
|||||||
|
|
||||||
| Mandatory | Minimum | Recommend |
|
| Mandatory | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.8 | 3.11 |
|
| python | 3.9 | 3.10 |
|
||||||
| torch | 1.13.1 | 2.4.0 |
|
| torch | 1.13.1 | 2.5.1 |
|
||||||
| transformers | 4.41.2 | 4.43.4 |
|
| transformers | 4.41.2 | 4.49.0 |
|
||||||
| datasets | 2.16.0 | 2.20.0 |
|
| datasets | 2.16.0 | 3.2.0 |
|
||||||
| accelerate | 0.30.1 | 0.32.0 |
|
| accelerate | 0.34.0 | 1.2.1 |
|
||||||
| peft | 0.11.1 | 0.12.0 |
|
| peft | 0.11.1 | 0.12.0 |
|
||||||
| trl | 0.8.6 | 0.9.6 |
|
| trl | 0.8.6 | 0.9.6 |
|
||||||
|
|
||||||
| Optional | Minimum | Recommend |
|
| Optional | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| CUDA | 11.6 | 12.2 |
|
| CUDA | 11.6 | 12.2 |
|
||||||
| deepspeed | 0.10.0 | 0.14.0 |
|
| deepspeed | 0.10.0 | 0.16.4 |
|
||||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||||
| vllm | 0.4.3 | 0.5.0 |
|
| vllm | 0.4.3 | 0.7.3 |
|
||||||
| flash-attn | 2.3.0 | 2.6.3 |
|
| flash-attn | 2.3.0 | 2.7.2 |
|
||||||
|
|
||||||
### Hardware Requirement
|
### Hardware Requirement
|
||||||
|
|
||||||
\* *estimated*
|
\* *estimated*
|
||||||
|
|
||||||
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
|
| Method | Bits | 7B | 14B | 30B | 70B | `x`B |
|
||||||
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
|
| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
|
||||||
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
|
| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
|
||||||
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
|
| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
|
||||||
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
|
| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
|
||||||
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
|
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
|
||||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
|
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
|
||||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
|
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
|
||||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
|
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
@@ -356,47 +434,67 @@ cd LLaMA-Factory
|
|||||||
pip install -e ".[torch,metrics]"
|
pip install -e ".[torch,metrics]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
|
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
||||||
|
|
||||||
|
<details><summary>Setting up a virtual environment with <b>uv</b></summary>
|
||||||
|
|
||||||
|
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync --extra torch --extra metrics --prerelease=allow
|
||||||
|
```
|
||||||
|
|
||||||
|
Run LLaMA-Factory in the isolated environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details><summary>For Windows users</summary>
|
<details><summary>For Windows users</summary>
|
||||||
|
|
||||||
|
#### Install BitsAndBytes
|
||||||
|
|
||||||
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
|
#### Install Flash Attention-2
|
||||||
|
|
||||||
|
To enable FlashAttention-2 on the Windows platform, please use the script from [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) to compile and install it by yourself.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details><summary>For Ascend NPU users</summary>
|
<details><summary>For Ascend NPU users</summary>
|
||||||
|
|
||||||
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# replace the url according to your CANN version and devices
|
# replace the url according to your CANN version and devices
|
||||||
# install CANN Toolkit
|
# install CANN Toolkit
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
|
||||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
||||||
|
|
||||||
# install CANN Kernels
|
# install CANN Kernels
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
|
||||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
||||||
|
|
||||||
# set env variables
|
# set env variables
|
||||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
| Requirement | Minimum | Recommend |
|
| Requirement | Minimum | Recommend |
|
||||||
| ------------ | ------- | ----------- |
|
| ------------ | ------- | -------------- |
|
||||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||||
| torch | 2.1.0 | 2.1.0 |
|
| torch | 2.1.0 | 2.4.0 |
|
||||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
| deepspeed | 0.13.2 | 0.13.2 |
|
||||||
|
|
||||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
||||||
|
|
||||||
@@ -404,11 +502,45 @@ If you cannot infer model on NPU devices, try setting `do_sample: false` in the
|
|||||||
|
|
||||||
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||||
|
|
||||||
|
#### Install BitsAndBytes
|
||||||
|
|
||||||
|
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
||||||
|
|
||||||
|
1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install bitsandbytes from source
|
||||||
|
# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch
|
||||||
|
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
|
||||||
|
cd bitsandbytes/
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
# Install the dependencies for the compilation tools. Note that the commands for this step may vary depending on the operating system. The following are provided for reference
|
||||||
|
apt-get install -y build-essential cmake
|
||||||
|
|
||||||
|
# Compile & install
|
||||||
|
cmake -DCOMPUTE_BACKEND=npu -S .
|
||||||
|
make
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install transformers from the main branch.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone -b main https://github.com/huggingface/transformers.git
|
||||||
|
cd transformers
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### Data Preparation
|
### Data Preparation
|
||||||
|
|
||||||
Please refer to [data/README.md](data/README.md) for details about the format of dataset files. You can either use datasets on the HuggingFace / ModelScope hub or load datasets from your local disk.
|
Please refer to [data/README.md](data/README.md) for details about the format of dataset files. You can either use datasets on the HuggingFace / ModelScope / Modelers hub or load datasets from your local disk.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Please update `data/dataset_info.json` to use your custom dataset.
|
> Please update `data/dataset_info.json` to use your custom dataset.
|
||||||
@@ -427,6 +559,8 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Use `llamafactory-cli help` to show help information.
|
> Use `llamafactory-cli help` to show help information.
|
||||||
|
>
|
||||||
|
> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.
|
||||||
|
|
||||||
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
||||||
|
|
||||||
@@ -476,6 +610,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
|
|||||||
docker run -dit --gpus=all \
|
docker run -dit --gpus=all \
|
||||||
-v ./hf_cache:/root/.cache/huggingface \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
-v ./ms_cache:/root/.cache/modelscope \
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-p 7860:7860 \
|
-p 7860:7860 \
|
||||||
@@ -500,6 +635,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
|
|||||||
docker run -dit \
|
docker run -dit \
|
||||||
-v ./hf_cache:/root/.cache/huggingface \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
-v ./ms_cache:/root/.cache/modelscope \
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-v /usr/local/dcmi:/usr/local/dcmi \
|
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||||
@@ -533,6 +669,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
|
|||||||
docker run -dit \
|
docker run -dit \
|
||||||
-v ./hf_cache:/root/.cache/huggingface \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
-v ./ms_cache:/root/.cache/modelscope \
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-v ./saves:/app/saves \
|
-v ./saves:/app/saves \
|
||||||
@@ -553,6 +690,7 @@ docker exec -it llamafactory bash
|
|||||||
|
|
||||||
- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
||||||
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
|
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
|
||||||
|
- `om_cache`: Similar to Hugging Face cache but for Modelers users.
|
||||||
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
||||||
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
||||||
|
|
||||||
@@ -566,6 +704,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
|
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
|
||||||
|
>
|
||||||
|
> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)
|
||||||
|
|
||||||
### Download from ModelScope Hub
|
### Download from ModelScope Hub
|
||||||
|
|
||||||
@@ -577,6 +717,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
|||||||
|
|
||||||
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||||
|
|
||||||
|
### Download from Modelers Hub
|
||||||
|
|
||||||
|
You can also use Modelers Hub to download models and datasets.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
|
||||||
|
|
||||||
### Use W&B Logger
|
### Use W&B Logger
|
||||||
|
|
||||||
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
||||||
@@ -588,6 +738,21 @@ run_name: test_run # optional
|
|||||||
|
|
||||||
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
|
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
|
||||||
|
|
||||||
|
### Use SwanLab Logger
|
||||||
|
|
||||||
|
To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results, you need to add the following arguments to yaml files.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_run_name: test_run # optional
|
||||||
|
```
|
||||||
|
|
||||||
|
When launching training tasks, you can log in to SwanLab in three ways:
|
||||||
|
|
||||||
|
1. Add `swanlab_api_key=<your_api_key>` to the yaml file, and set it to your [API key](https://swanlab.cn/settings).
|
||||||
|
2. Set the environment variable `SWANLAB_API_KEY` to your [API key](https://swanlab.cn/settings).
|
||||||
|
3. Use the `swanlab login` command to complete the login.
|
||||||
|
|
||||||
## Projects using LLaMA Factory
|
## Projects using LLaMA Factory
|
||||||
|
|
||||||
If you have a project that should be incorporated, please contact via email or create a pull request.
|
If you have a project that should be incorporated, please contact via email or create a pull request.
|
||||||
@@ -675,16 +840,20 @@ If you have a project that should be incorporated, please contact via email or c
|
|||||||
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
||||||
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
||||||
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
||||||
|
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
|
||||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
||||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
||||||
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
||||||
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
|
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
|
||||||
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
|
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
|
||||||
|
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
||||||
|
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
|
||||||
|
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -692,7 +861,7 @@ If you have a project that should be incorporated, please contact via email or c
|
|||||||
|
|
||||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||||
|
|
||||||
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / 
[Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|||||||
313
README_zh.md
313
README_zh.md
@@ -1,19 +1,32 @@
|
|||||||

|

|
||||||
|
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||||
[](LICENSE)
|
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
|
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
||||||
|
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
||||||
[](https://pypi.org/project/llamafactory/)
|
[](https://pypi.org/project/llamafactory/)
|
||||||
[](#使用了-llama-factory-的项目)
|
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
|
||||||
[](https://twitter.com/llamafactory_ai)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
|
[](https://gitcode.com/zhengyaowei/LLaMA-Factory)
|
||||||
|
|
||||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
|
<h3 align="center">
|
||||||
|
使用零代码<a href="#快速开始">命令行</a>与 <a href="#llama-board-可视化微调由-gradio-驱动">Web UI</a> 轻松微调百余种大模型
|
||||||
|
</h3>
|
||||||
|
<p align="center">
|
||||||
|
<picture>
|
||||||
|
<img alt="Github trend" src="https://trendshift.io/api/badge/repositories/4535">
|
||||||
|
</picture>
|
||||||
|
</p>
|
||||||
|
|
||||||
[](https://trendshift.io/repositories/4535)
|
|
||||||
|
|
||||||
👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。
|
👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。
|
||||||
|
|
||||||
@@ -25,11 +38,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
|
|
||||||
选择你的打开方式:
|
选择你的打开方式:
|
||||||
|
|
||||||
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
|
||||||
- **PAI-DSW**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
|
||||||
- **本地机器**:请见[如何使用](#如何使用)
|
|
||||||
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||||
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||||
|
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||||
|
- **本地机器**:请见[如何使用](#如何使用)
|
||||||
|
- **PAI-DSW(免费试用)**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
|
||||||
|
- **Amazon SageMaker**:[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
@@ -41,6 +58,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
- [数据集](#数据集)
|
- [数据集](#数据集)
|
||||||
- [软硬件依赖](#软硬件依赖)
|
- [软硬件依赖](#软硬件依赖)
|
||||||
- [如何使用](#如何使用)
|
- [如何使用](#如何使用)
|
||||||
|
- [安装 LLaMA Factory](#安装-llama-factory)
|
||||||
|
- [数据准备](#数据准备)
|
||||||
|
- [快速开始](#快速开始)
|
||||||
|
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
|
||||||
|
- [构建 Docker](#构建-docker)
|
||||||
|
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
|
||||||
|
- [从魔搭社区下载](#从魔搭社区下载)
|
||||||
|
- [从魔乐社区下载](#从魔乐社区下载)
|
||||||
|
- [使用 W&B 面板](#使用-wb-面板)
|
||||||
|
- [使用 SwanLab 面板](#使用-swanlab-面板)
|
||||||
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
|
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
|
||||||
- [协议](#协议)
|
- [协议](#协议)
|
||||||
- [引用](#引用)
|
- [引用](#引用)
|
||||||
@@ -48,14 +75,22 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
|
|
||||||
## 项目特色
|
## 项目特色
|
||||||
|
|
||||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
|
||||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
||||||
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
|
||||||
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
|
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
|
||||||
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。
|
||||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
|
|
||||||
|
### 最新模型的 Day-N 微调适配
|
||||||
|
|
||||||
|
| 适配时间 | 模型名称 |
|
||||||
|
| ------------ | ---------------------------------------------------------- |
|
||||||
|
| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
|
||||||
|
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
|
||||||
|
|
||||||
## 性能指标
|
## 性能指标
|
||||||
|
|
||||||
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
||||||
@@ -73,14 +108,38 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
|
[25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。
|
||||||
|
|
||||||
|
[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
|
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
|
||||||
|
|
||||||
|
[25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。
|
||||||
|
|
||||||
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
|
[25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
|
[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.
|
||||||
|
|
||||||
|
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
|
||||||
|
|
||||||
|
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
|
||||||
|
|
||||||
|
[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
|
||||||
|
|
||||||
|
[24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
|
||||||
|
|
||||||
|
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
|
||||||
|
|
||||||
|
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
|
||||||
|
|
||||||
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
|
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
|
||||||
|
|
||||||
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
|
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
|
||||||
|
|
||||||
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
|
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
|
||||||
|
|
||||||
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
|
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
|
||||||
|
|
||||||
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
@@ -161,34 +220,51 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
|
|
||||||
## 模型
|
## 模型
|
||||||
|
|
||||||
| 模型名 | 模型大小 | Template |
|
| 模型名 | 参数量 | Template |
|
||||||
| ----------------------------------------------------------------- | -------------------------------- | --------- |
|
| ----------------------------------------------------------------- | -------------------------------- | ------------------- |
|
||||||
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
||||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
|
||||||
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
||||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
||||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
|
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
||||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
|
||||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
|
||||||
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
|
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||||
|
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||||
|
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||||
|
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||||
|
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
||||||
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
|
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
|
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||||
|
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
|
||||||
|
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
||||||
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
|
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
|
||||||
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||||
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||||
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||||
@@ -202,7 +278,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
## 训练方法
|
## 训练方法
|
||||||
|
|
||||||
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
||||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
| --------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||||
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
@@ -272,9 +348,13 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||||
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||||
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
|
||||||
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
||||||
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
||||||
|
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
||||||
|
- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
|
||||||
|
- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
|
||||||
|
- [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
|
||||||
|
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
|
||||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||||
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
||||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
@@ -314,35 +394,34 @@ huggingface-cli login
|
|||||||
|
|
||||||
| 必需项 | 至少 | 推荐 |
|
| 必需项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.8 | 3.11 |
|
| python | 3.9 | 3.10 |
|
||||||
| torch | 1.13.1 | 2.4.0 |
|
| torch | 1.13.1 | 2.5.1 |
|
||||||
| transformers | 4.41.2 | 4.43.4 |
|
| transformers | 4.41.2 | 4.49.0 |
|
||||||
| datasets | 2.16.0 | 2.20.0 |
|
| datasets | 2.16.0 | 3.2.0 |
|
||||||
| accelerate | 0.30.1 | 0.32.0 |
|
| accelerate | 0.34.0 | 1.2.1 |
|
||||||
| peft | 0.11.1 | 0.12.0 |
|
| peft | 0.11.1 | 0.12.0 |
|
||||||
| trl | 0.8.6 | 0.9.6 |
|
| trl | 0.8.6 | 0.9.6 |
|
||||||
|
|
||||||
| 可选项 | 至少 | 推荐 |
|
| 可选项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| CUDA | 11.6 | 12.2 |
|
| CUDA | 11.6 | 12.2 |
|
||||||
| deepspeed | 0.10.0 | 0.14.0 |
|
| deepspeed | 0.10.0 | 0.16.4 |
|
||||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||||
| vllm | 0.4.3 | 0.5.0 |
|
| vllm | 0.4.3 | 0.7.3 |
|
||||||
| flash-attn | 2.3.0 | 2.6.3 |
|
| flash-attn | 2.3.0 | 2.7.2 |
|
||||||
|
|
||||||
### 硬件依赖
|
### 硬件依赖
|
||||||
|
|
||||||
\* *估算值*
|
\* *估算值*
|
||||||
|
|
||||||
| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
|
| 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B |
|
||||||
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
|
| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
|
||||||
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
|
| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
|
||||||
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
|
| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
|
||||||
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
|
| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
|
||||||
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
|
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
|
||||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
|
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
|
||||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
|
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
|
||||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
|
|
||||||
|
|
||||||
## 如何使用
|
## 如何使用
|
||||||
|
|
||||||
@@ -357,26 +436,47 @@ cd LLaMA-Factory
|
|||||||
pip install -e ".[torch,metrics]"
|
pip install -e ".[torch,metrics]"
|
||||||
```
|
```
|
||||||
|
|
||||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
|
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
||||||
|
|
||||||
|
<details><summary>使用 <b>uv</b> 构建虚拟环境</summary>
|
||||||
|
|
||||||
|
使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync --extra torch --extra metrics --prerelease=allow
|
||||||
|
```
|
||||||
|
|
||||||
|
在环境中运行 LLaMA-Factory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
<details><summary>Windows 用户指南</summary>
|
<details><summary>Windows 用户指南</summary>
|
||||||
|
|
||||||
|
#### 安装 BitsAndBytes
|
||||||
|
|
||||||
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
|
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
|
#### 安装 Flash Attention-2
|
||||||
|
|
||||||
|
如果要在 Windows 平台上开启 FlashAttention-2,请使用 [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) 中的脚本自行编译与安装。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details><summary>昇腾 NPU 用户指南</summary>
|
<details><summary>昇腾 NPU 用户指南</summary>
|
||||||
|
|
||||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||||
@@ -392,12 +492,12 @@ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
|||||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
| 依赖项 | 至少 | 推荐 |
|
| 依赖项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | ----------- |
|
| ------------ | ------- | -------------- |
|
||||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||||
| torch | 2.1.0 | 2.1.0 |
|
| torch | 2.1.0 | 2.4.0 |
|
||||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
| deepspeed | 0.13.2 | 0.13.2 |
|
||||||
|
|
||||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
||||||
|
|
||||||
@@ -405,11 +505,45 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|||||||
|
|
||||||
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||||
|
|
||||||
|
#### 安装 BitsAndBytes
|
||||||
|
|
||||||
|
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
||||||
|
|
||||||
|
1. 手动编译 bitsandbytes:请参考[安装文档](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU)完成 NPU 版的 bitsandbytes 安装,编译要求环境 cmake 版本不低于 3.22.1,g++ 版本不低于 12.x。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 从源码安装 bitsandbytes
|
||||||
|
# 克隆 bitsandbytes 仓库, Ascend NPU 目前在 multi-backend-refactor 中支持
|
||||||
|
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
|
||||||
|
cd bitsandbytes/
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
# 安装编译工具依赖,该步骤在不同系统上命令有所不同,供参考
|
||||||
|
apt-get install -y build-essential cmake
|
||||||
|
|
||||||
|
# 编译 & 安装
|
||||||
|
cmake -DCOMPUTE_BACKEND=npu -S .
|
||||||
|
make
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 安装 transformers 的 main 分支版本。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone -b main https://github.com/huggingface/transformers.git
|
||||||
|
cd transformers
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### 数据准备
|
### 数据准备
|
||||||
|
|
||||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
|
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
|
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
|
||||||
@@ -428,6 +562,8 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 使用 `llamafactory-cli help` 显示帮助信息。
|
> 使用 `llamafactory-cli help` 显示帮助信息。
|
||||||
|
>
|
||||||
|
> 遇到报错请先看[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)。
|
||||||
|
|
||||||
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
||||||
|
|
||||||
@@ -477,6 +613,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
|
|||||||
docker run -dit --gpus=all \
|
docker run -dit --gpus=all \
|
||||||
-v ./hf_cache:/root/.cache/huggingface \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
-v ./ms_cache:/root/.cache/modelscope \
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-p 7860:7860 \
|
-p 7860:7860 \
|
||||||
@@ -501,6 +638,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
|
|||||||
docker run -dit \
|
docker run -dit \
|
||||||
-v ./hf_cache:/root/.cache/huggingface \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
-v ./ms_cache:/root/.cache/modelscope \
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-v /usr/local/dcmi:/usr/local/dcmi \
|
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||||
@@ -534,6 +672,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
|
|||||||
docker run -dit \
|
docker run -dit \
|
||||||
-v ./hf_cache:/root/.cache/huggingface \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
-v ./ms_cache:/root/.cache/modelscope \
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-v ./saves:/app/saves \
|
-v ./saves:/app/saves \
|
||||||
@@ -554,6 +693,7 @@ docker exec -it llamafactory bash
|
|||||||
|
|
||||||
- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
|
- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
|
||||||
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
|
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
|
||||||
|
- `om_cache`:类似 Hugging Face 缓存文件夹,为 Modelers 用户提供。
|
||||||
- `data`:宿主机中存放数据集的文件夹路径。
|
- `data`:宿主机中存放数据集的文件夹路径。
|
||||||
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
|
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
|
||||||
|
|
||||||
@@ -567,6 +707,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
|
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
|
||||||
|
>
|
||||||
|
> 示例:[图像理解](scripts/api_example/test_image.py) | [工具调用](scripts/api_example/test_toolcall.py)
|
||||||
|
|
||||||
### 从魔搭社区下载
|
### 从魔搭社区下载
|
||||||
|
|
||||||
@@ -578,6 +720,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
|||||||
|
|
||||||
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||||
|
|
||||||
|
### 从魔乐社区下载
|
||||||
|
|
||||||
|
您也可以通过下述方法,使用魔乐社区下载数据集和模型。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export USE_OPENMIND_HUB=1 # Windows 使用 `set USE_OPENMIND_HUB=1`
|
||||||
|
```
|
||||||
|
|
||||||
|
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔乐社区](https://modelers.cn/models)查看所有可用的模型,例如 `TeleAI/TeleChat-7B-pt`。
|
||||||
|
|
||||||
### 使用 W&B 面板
|
### 使用 W&B 面板
|
||||||
|
|
||||||
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
||||||
@@ -589,6 +741,21 @@ run_name: test_run # 可选
|
|||||||
|
|
||||||
在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
|
在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
|
||||||
|
|
||||||
|
### 使用 SwanLab 面板
|
||||||
|
|
||||||
|
若要使用 [SwanLab](https://github.com/SwanHubX/SwanLab) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_run_name: test_run # 可选
|
||||||
|
```
|
||||||
|
|
||||||
|
在启动训练任务时,登录SwanLab账户有以下三种方式:
|
||||||
|
|
||||||
|
方式一:在 yaml 文件中添加 `swanlab_api_key=<your_api_key>` ,并设置为你的 [API 密钥](https://swanlab.cn/settings)。
|
||||||
|
方式二:将环境变量 `SWANLAB_API_KEY` 设置为你的 [API 密钥](https://swanlab.cn/settings)。
|
||||||
|
方式三:启动前使用 `swanlab login` 命令完成登录。
|
||||||
|
|
||||||
## 使用了 LLaMA Factory 的项目
|
## 使用了 LLaMA Factory 的项目
|
||||||
|
|
||||||
如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
|
如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
|
||||||
@@ -676,16 +843,20 @@ run_name: test_run # 可选
|
|||||||
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
||||||
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
||||||
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
||||||
|
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
|
||||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
|
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
|
||||||
1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
|
1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
|
||||||
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
|
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
|
||||||
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调.
|
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调.
|
||||||
|
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
||||||
|
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。
|
||||||
|
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**:由 NovaSky AI 微调的低成本类 o1 长推理模型。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -693,7 +864,7 @@ run_name: test_run # 可选
|
|||||||
|
|
||||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||||
|
|
||||||
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 
2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||||
|
|
||||||
## 引用
|
## 引用
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
|||||||
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
||||||
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
||||||
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
|
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
|
||||||
|
"audios": "the column name in the dataset containing the audios inputs. (default: None)",
|
||||||
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
||||||
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
||||||
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
||||||
@@ -150,6 +151,10 @@ An additional column `images` is required. Please refer to the [sharegpt](#share
|
|||||||
|
|
||||||
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
|
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
|
||||||
|
|
||||||
|
### Multimodal Audio Dataset
|
||||||
|
|
||||||
|
An additional column `audios` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
|
||||||
|
|
||||||
## Sharegpt Format
|
## Sharegpt Format
|
||||||
|
|
||||||
### Supervised Fine-Tuning Dataset
|
### Supervised Fine-Tuning Dataset
|
||||||
@@ -296,7 +301,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
|
|
||||||
- [Example dataset](mllm_demo.json)
|
- [Example dataset](mllm_demo.json)
|
||||||
|
|
||||||
Multimodal image datasets require a `images` column containing the paths to the input images.
|
Multimodal image datasets require an `images` column containing the paths to the input images.
|
||||||
|
|
||||||
The number of images should be identical to the `<image>` tokens in the conversations.
|
The number of images should be identical to the `<image>` tokens in the conversations.
|
||||||
|
|
||||||
@@ -374,6 +379,47 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Multimodal Audio Dataset
|
||||||
|
|
||||||
|
- [Example dataset](mllm_audio_demo.json)
|
||||||
|
|
||||||
|
Multimodal audio datasets require an `audios` column containing the paths to the input audios.
|
||||||
|
|
||||||
|
The number of audios should be identical to the `<audio>` tokens in the conversations.
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "<audio>human instruction"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "model response"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"audios": [
|
||||||
|
"audio path (required)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"dataset_name": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"audios": "audios"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### OpenAI Format
|
### OpenAI Format
|
||||||
|
|
||||||
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
|
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
||||||
"images": "数据集代表图像输入的表头名称(默认:None)",
|
"images": "数据集代表图像输入的表头名称(默认:None)",
|
||||||
"videos": "数据集代表视频输入的表头名称(默认:None)",
|
"videos": "数据集代表视频输入的表头名称(默认:None)",
|
||||||
|
"audios": "数据集代表音频输入的表头名称(默认:None)",
|
||||||
"chosen": "数据集代表更优回答的表头名称(默认:None)",
|
"chosen": "数据集代表更优回答的表头名称(默认:None)",
|
||||||
"rejected": "数据集代表更差回答的表头名称(默认:None)",
|
"rejected": "数据集代表更差回答的表头名称(默认:None)",
|
||||||
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
||||||
@@ -150,6 +151,10 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
|
|||||||
|
|
||||||
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
|
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
|
||||||
|
|
||||||
|
### 多模态音频数据集
|
||||||
|
|
||||||
|
多模态音频数据集需要提供额外的 `audios` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
|
||||||
|
|
||||||
## Sharegpt 格式
|
## Sharegpt 格式
|
||||||
|
|
||||||
### 指令监督微调数据集
|
### 指令监督微调数据集
|
||||||
@@ -374,6 +379,48 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 多模态音频数据集
|
||||||
|
|
||||||
|
- [样例数据集](mllm_audio_demo.json)
|
||||||
|
|
||||||
|
多模态音频数据集需要额外添加一个 `audios` 列,包含输入音频的路径。
|
||||||
|
|
||||||
|
注意音频的数量必须与文本中所有 `<audio>` 标记的数量严格一致。
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "<audio>人类指令"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "模型回答"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"audios": [
|
||||||
|
"音频路径(必填)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"数据集名称": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"audios": "audios"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### OpenAI 格式
|
### OpenAI 格式
|
||||||
|
|
||||||
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
|
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ _CITATION = """\
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
|
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
|
||||||
_LICENSE = "gpl-3.0"
|
_LICENSE = "gpl-3.0"
|
||||||
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
|
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
|
||||||
|
|
||||||
|
|
||||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||||
@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|||||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||||
|
|
||||||
def _generate_examples(self, filepath: str):
|
def _generate_examples(self, filepath: str):
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, encoding="utf-8") as f:
|
||||||
for key, row in enumerate(f):
|
for key, row in enumerate(f):
|
||||||
data = json.loads(row)
|
data = json.loads(row)
|
||||||
conversations = []
|
conversations = []
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ import datasets
|
|||||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||||
_CITATION = ""
|
_CITATION = ""
|
||||||
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
|
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
|
||||||
_LICENSE = "mit"
|
_LICENSE = "mit"
|
||||||
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
|
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
|
||||||
_URLS = {
|
_URLS = {
|
||||||
"train": [
|
"train": [
|
||||||
_URL + "harmless-base/train.jsonl.gz",
|
_URL + "harmless-base/train.jsonl.gz",
|
||||||
@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
|||||||
def _generate_examples(self, filepaths: List[str]):
|
def _generate_examples(self, filepaths: List[str]):
|
||||||
key = 0
|
key = 0
|
||||||
for filepath in filepaths:
|
for filepath in filepaths:
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, encoding="utf-8") as f:
|
||||||
for row in f:
|
for row in f:
|
||||||
data = json.loads(row)
|
data = json.loads(row)
|
||||||
chosen = data["chosen"]
|
chosen = data["chosen"]
|
||||||
|
|||||||
BIN
data/mllm_demo_data/1.mp3
Normal file
BIN
data/mllm_demo_data/1.mp3
Normal file
Binary file not shown.
BIN
data/mllm_demo_data/2.wav
Normal file
BIN
data/mllm_demo_data/2.wav
Normal file
Binary file not shown.
BIN
data/mllm_demo_data/3.flac
Normal file
BIN
data/mllm_demo_data/3.flac
Normal file
Binary file not shown.
@@ -20,9 +20,9 @@ _CITATION = """\
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
|
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
|
||||||
_LICENSE = "cc-by-nc-4.0"
|
_LICENSE = "cc-by-nc-4.0"
|
||||||
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
|
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
|
||||||
|
|
||||||
|
|
||||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||||
@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
|||||||
|
|
||||||
def _generate_examples(self, filepaths: List[str]):
|
def _generate_examples(self, filepaths: List[str]):
|
||||||
for filepath in filepaths:
|
for filepath in filepaths:
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, encoding="utf-8") as f:
|
||||||
for row in f:
|
for row in f:
|
||||||
try:
|
try:
|
||||||
data = json.loads(row)
|
data = json.loads(row)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Use the NVIDIA official image with PyTorch 2.3.0
|
# Default use the NVIDIA official image with PyTorch 2.3.0
|
||||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
|
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
|
||||||
FROM nvcr.io/nvidia/pytorch:24.02-py3
|
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
|
||||||
|
FROM ${BASE_IMAGE}
|
||||||
|
|
||||||
# Define environments
|
# Define environments
|
||||||
ENV MAX_JOBS=4
|
ENV MAX_JOBS=4
|
||||||
@@ -12,17 +13,32 @@ ARG INSTALL_BNB=false
|
|||||||
ARG INSTALL_VLLM=false
|
ARG INSTALL_VLLM=false
|
||||||
ARG INSTALL_DEEPSPEED=false
|
ARG INSTALL_DEEPSPEED=false
|
||||||
ARG INSTALL_FLASHATTN=false
|
ARG INSTALL_FLASHATTN=false
|
||||||
|
ARG INSTALL_LIGER_KERNEL=false
|
||||||
|
ARG INSTALL_HQQ=false
|
||||||
|
ARG INSTALL_EETQ=false
|
||||||
ARG PIP_INDEX=https://pypi.org/simple
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
|
ARG HTTP_PROXY=
|
||||||
|
|
||||||
# Set the working directory
|
# Set the working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Set http proxy
|
||||||
|
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
echo "Configuring proxy..."; \
|
||||||
|
export http_proxy=$HTTP_PROXY; \
|
||||||
|
export https_proxy=$HTTP_PROXY; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Install the requirements
|
# Install the requirements
|
||||||
COPY requirements.txt /app
|
COPY requirements.txt /app
|
||||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||||
pip config set global.extra-index-url "$PIP_INDEX" && \
|
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||||
python -m pip install --upgrade pip && \
|
python -m pip install --upgrade pip && \
|
||||||
python -m pip install -r requirements.txt
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||||
|
else \
|
||||||
|
python -m pip install -r requirements.txt; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Copy the rest of the application into the image
|
# Copy the rest of the application into the image
|
||||||
COPY . /app
|
COPY . /app
|
||||||
@@ -38,13 +54,39 @@ RUN EXTRA_PACKAGES="metrics"; \
|
|||||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
fi; \
|
fi; \
|
||||||
pip install -e ".[$EXTRA_PACKAGES]"
|
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_EETQ" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
|
||||||
|
fi; \
|
||||||
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||||
|
else \
|
||||||
|
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN pip uninstall -y transformer-engine flash-attn && \
|
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||||
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||||
pip uninstall -y ninja && pip install ninja && \
|
pip uninstall -y ninja && \
|
||||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
pip install --proxy=$HTTP_PROXY ninja && \
|
||||||
|
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
|
||||||
|
else \
|
||||||
|
pip install ninja && \
|
||||||
|
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||||
|
fi; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
# Unset http proxy
|
||||||
|
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
unset http_proxy; \
|
||||||
|
unset https_proxy; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
|
|||||||
@@ -4,15 +4,19 @@ services:
|
|||||||
dockerfile: ./docker/docker-cuda/Dockerfile
|
dockerfile: ./docker/docker-cuda/Dockerfile
|
||||||
context: ../..
|
context: ../..
|
||||||
args:
|
args:
|
||||||
INSTALL_BNB: false
|
INSTALL_BNB: "false"
|
||||||
INSTALL_VLLM: false
|
INSTALL_VLLM: "false"
|
||||||
INSTALL_DEEPSPEED: false
|
INSTALL_DEEPSPEED: "false"
|
||||||
INSTALL_FLASHATTN: false
|
INSTALL_FLASHATTN: "false"
|
||||||
|
INSTALL_LIGER_KERNEL: "false"
|
||||||
|
INSTALL_HQQ: "false"
|
||||||
|
INSTALL_EETQ: "false"
|
||||||
PIP_INDEX: https://pypi.org/simple
|
PIP_INDEX: https://pypi.org/simple
|
||||||
container_name: llamafactory
|
container_name: llamafactory
|
||||||
volumes:
|
volumes:
|
||||||
- ../../hf_cache:/root/.cache/huggingface
|
- ../../hf_cache:/root/.cache/huggingface
|
||||||
- ../../ms_cache:/root/.cache/modelscope
|
- ../../ms_cache:/root/.cache/modelscope
|
||||||
|
- ../../om_cache:/root/.cache/openmind
|
||||||
- ../../data:/app/data
|
- ../../data:/app/data
|
||||||
- ../../output:/app/output
|
- ../../output:/app/output
|
||||||
ports:
|
ports:
|
||||||
@@ -20,6 +24,7 @@ services:
|
|||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
ipc: host
|
ipc: host
|
||||||
tty: true
|
tty: true
|
||||||
|
shm_size: "16gb"
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
command: bash
|
command: bash
|
||||||
deploy:
|
deploy:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
|
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
|
||||||
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
|
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
|
||||||
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
|
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
|
||||||
FROM ascendai/cann:8.0.rc1-910b-ubuntu22.04-py3.8
|
FROM ascendai/cann:8.0.0-910b-ubuntu22.04-py3.10
|
||||||
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
|
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
|
||||||
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8
|
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8
|
||||||
|
|
||||||
@@ -12,16 +12,28 @@ ENV DEBIAN_FRONTEND=noninteractive
|
|||||||
ARG INSTALL_DEEPSPEED=false
|
ARG INSTALL_DEEPSPEED=false
|
||||||
ARG PIP_INDEX=https://pypi.org/simple
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
|
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
|
||||||
|
ARG HTTP_PROXY=
|
||||||
|
|
||||||
# Set the working directory
|
# Set the working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Set http proxy
|
||||||
|
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
echo "Configuring proxy..."; \
|
||||||
|
export http_proxy=$HTTP_PROXY; \
|
||||||
|
export https_proxy=$HTTP_PROXY; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Install the requirements
|
# Install the requirements
|
||||||
COPY requirements.txt /app
|
COPY requirements.txt /app
|
||||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||||
pip config set global.extra-index-url "$TORCH_INDEX" && \
|
pip config set global.extra-index-url "$TORCH_INDEX" && \
|
||||||
python -m pip install --upgrade pip && \
|
python -m pip install --upgrade pip && \
|
||||||
python -m pip install -r requirements.txt
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||||
|
else \
|
||||||
|
python -m pip install -r requirements.txt; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Copy the rest of the application into the image
|
# Copy the rest of the application into the image
|
||||||
COPY . /app
|
COPY . /app
|
||||||
@@ -31,7 +43,17 @@ RUN EXTRA_PACKAGES="torch-npu,metrics"; \
|
|||||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
fi; \
|
fi; \
|
||||||
pip install -e ".[$EXTRA_PACKAGES]"
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||||
|
else \
|
||||||
|
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Unset http proxy
|
||||||
|
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
unset http_proxy; \
|
||||||
|
unset https_proxy; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ services:
|
|||||||
dockerfile: ./docker/docker-npu/Dockerfile
|
dockerfile: ./docker/docker-npu/Dockerfile
|
||||||
context: ../..
|
context: ../..
|
||||||
args:
|
args:
|
||||||
INSTALL_DEEPSPEED: false
|
INSTALL_DEEPSPEED: "false"
|
||||||
PIP_INDEX: https://pypi.org/simple
|
PIP_INDEX: https://pypi.org/simple
|
||||||
container_name: llamafactory
|
container_name: llamafactory
|
||||||
volumes:
|
volumes:
|
||||||
- ../../hf_cache:/root/.cache/huggingface
|
- ../../hf_cache:/root/.cache/huggingface
|
||||||
- ../../ms_cache:/root/.cache/modelscope
|
- ../../ms_cache:/root/.cache/modelscope
|
||||||
|
- ../../om_cache:/root/.cache/openmind
|
||||||
- ../../data:/app/data
|
- ../../data:/app/data
|
||||||
- ../../output:/app/output
|
- ../../output:/app/output
|
||||||
- /usr/local/dcmi:/usr/local/dcmi
|
- /usr/local/dcmi:/usr/local/dcmi
|
||||||
@@ -21,6 +22,7 @@ services:
|
|||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
ipc: host
|
ipc: host
|
||||||
tty: true
|
tty: true
|
||||||
|
shm_size: "16gb"
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
command: bash
|
command: bash
|
||||||
devices:
|
devices:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM hardandheavy/transformers-rocm:2.1.0
|
FROM hardandheavy/transformers-rocm:2.2.0
|
||||||
|
|
||||||
# Define environments
|
# Define environments
|
||||||
ENV MAX_JOBS=4
|
ENV MAX_JOBS=4
|
||||||
@@ -10,17 +10,31 @@ ARG INSTALL_BNB=false
|
|||||||
ARG INSTALL_VLLM=false
|
ARG INSTALL_VLLM=false
|
||||||
ARG INSTALL_DEEPSPEED=false
|
ARG INSTALL_DEEPSPEED=false
|
||||||
ARG INSTALL_FLASHATTN=false
|
ARG INSTALL_FLASHATTN=false
|
||||||
|
ARG INSTALL_LIGER_KERNEL=false
|
||||||
|
ARG INSTALL_HQQ=false
|
||||||
ARG PIP_INDEX=https://pypi.org/simple
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
|
ARG HTTP_PROXY=
|
||||||
|
|
||||||
# Set the working directory
|
# Set the working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Set http proxy
|
||||||
|
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
echo "Configuring proxy..."; \
|
||||||
|
export http_proxy=$HTTP_PROXY; \
|
||||||
|
export https_proxy=$HTTP_PROXY; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Install the requirements
|
# Install the requirements
|
||||||
COPY requirements.txt /app
|
COPY requirements.txt /app
|
||||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||||
pip config set global.extra-index-url "$PIP_INDEX" && \
|
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||||
python -m pip install --upgrade pip && \
|
python -m pip install --upgrade pip && \
|
||||||
python -m pip install -r requirements.txt
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||||
|
else \
|
||||||
|
python -m pip install -r requirements.txt; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Copy the rest of the application into the image
|
# Copy the rest of the application into the image
|
||||||
COPY . /app
|
COPY . /app
|
||||||
@@ -36,13 +50,35 @@ RUN EXTRA_PACKAGES="metrics"; \
|
|||||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
fi; \
|
fi; \
|
||||||
pip install -e ".[$EXTRA_PACKAGES]"
|
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||||
|
fi; \
|
||||||
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||||
|
else \
|
||||||
|
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN pip uninstall -y transformer-engine flash-attn && \
|
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||||
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||||
pip uninstall -y ninja && pip install ninja && \
|
pip uninstall -y ninja && \
|
||||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
pip install --proxy=$HTTP_PROXY ninja && \
|
||||||
|
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
|
||||||
|
else \
|
||||||
|
pip install ninja && \
|
||||||
|
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||||
|
fi; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Unset http proxy
|
||||||
|
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||||
|
unset http_proxy; \
|
||||||
|
unset https_proxy; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
|
|||||||
@@ -4,15 +4,18 @@ services:
|
|||||||
dockerfile: ./docker/docker-rocm/Dockerfile
|
dockerfile: ./docker/docker-rocm/Dockerfile
|
||||||
context: ../..
|
context: ../..
|
||||||
args:
|
args:
|
||||||
INSTALL_BNB: false
|
INSTALL_BNB: "false"
|
||||||
INSTALL_VLLM: false
|
INSTALL_VLLM: "false"
|
||||||
INSTALL_DEEPSPEED: false
|
INSTALL_DEEPSPEED: "false"
|
||||||
INSTALL_FLASHATTN: false
|
INSTALL_FLASHATTN: "false"
|
||||||
|
INSTALL_LIGER_KERNEL: "false"
|
||||||
|
INSTALL_HQQ: "false"
|
||||||
PIP_INDEX: https://pypi.org/simple
|
PIP_INDEX: https://pypi.org/simple
|
||||||
container_name: llamafactory
|
container_name: llamafactory
|
||||||
volumes:
|
volumes:
|
||||||
- ../../hf_cache:/root/.cache/huggingface
|
- ../../hf_cache:/root/.cache/huggingface
|
||||||
- ../../ms_cache:/root/.cache/modelscope
|
- ../../ms_cache:/root/.cache/modelscope
|
||||||
|
- ../../om_cache:/root/.cache/openmind
|
||||||
- ../../data:/app/data
|
- ../../data:/app/data
|
||||||
- ../../output:/app/output
|
- ../../output:/app/output
|
||||||
- ../../saves:/app/saves
|
- ../../saves:/app/saves
|
||||||
@@ -21,6 +24,7 @@ services:
|
|||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
ipc: host
|
ipc: host
|
||||||
tty: true
|
tty: true
|
||||||
|
shm_size: "16gb"
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
command: bash
|
command: bash
|
||||||
devices:
|
devices:
|
||||||
|
|||||||
@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
|||||||
df = pd.read_csv(filepath, header=None)
|
df = pd.read_csv(filepath, header=None)
|
||||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||||
|
|
||||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
yield from enumerate(df.to_dict(orient="records"))
|
||||||
yield i, instance
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
|
|||||||
|
|
||||||
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
|
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
|
||||||
|
|
||||||
|
By default, LLaMA-Factory uses all visible computing devices.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
### LoRA Fine-Tuning
|
### LoRA Fine-Tuning
|
||||||
@@ -80,17 +82,11 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
|||||||
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||||
@@ -99,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
|
|||||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with Ray on 4 GPUs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
|
||||||
|
```
|
||||||
|
|
||||||
### QLoRA Fine-Tuning
|
### QLoRA Fine-Tuning
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
||||||
@@ -107,6 +109,12 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
|
|||||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -130,14 +138,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
|||||||
#### Supervised Fine-Tuning on Single Node
|
#### Supervised Fine-Tuning on Single Node
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multimodal Supervised Fine-Tuning
|
#### Multimodal Supervised Fine-Tuning
|
||||||
@@ -146,12 +154,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
|
|||||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### Merging LoRA Adapters and Quantization
|
### Merging LoRA Adapters and Quantization
|
||||||
|
|
||||||
#### Merge LoRA Adapters
|
#### Merge LoRA Adapters
|
||||||
@@ -168,15 +170,27 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
|||||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Save Ollama modelfile
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
### Inferring LoRA Fine-Tuned Models
|
### Inferring LoRA Fine-Tuned Models
|
||||||
|
|
||||||
#### Use CLI
|
#### Batch Generation using vLLM Tensor Parallel
|
||||||
|
|
||||||
|
```
|
||||||
|
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Use CLI ChatBox
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Use Web UI
|
#### Use Web UI ChatBox
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||||
@@ -196,6 +210,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
|||||||
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Full-Parameter Fine-Tuning using APOLLO
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### Full-Parameter Fine-Tuning using BAdam
|
#### Full-Parameter Fine-Tuning using BAdam
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -238,3 +258,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
|||||||
```bash
|
```bash
|
||||||
bash examples/extras/fsdp_qlora/train.sh
|
bash examples/extras/fsdp_qlora/train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Computing BLEU and ROUGE Scores
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
|
||||||
|
```
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
|
使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
|
||||||
|
|
||||||
|
LLaMA-Factory 默认使用所有可见的计算设备。
|
||||||
|
|
||||||
## 示例
|
## 示例
|
||||||
|
|
||||||
### LoRA 微调
|
### LoRA 微调
|
||||||
@@ -80,17 +82,11 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
|||||||
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 多机指令监督微调
|
#### 多机指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||||
@@ -99,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
|
|||||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 使用 Ray 在 4 张 GPU 上微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
|
||||||
|
```
|
||||||
|
|
||||||
### QLoRA 微调
|
### QLoRA 微调
|
||||||
|
|
||||||
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
||||||
@@ -107,6 +109,12 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
|
|||||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -130,14 +138,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
|||||||
#### 在单机上进行指令监督微调
|
#### 在单机上进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 在多机上进行指令监督微调
|
#### 在多机上进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 多模态指令监督微调
|
#### 多模态指令监督微调
|
||||||
@@ -146,12 +154,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
|
|||||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### 合并 LoRA 适配器与模型量化
|
### 合并 LoRA 适配器与模型量化
|
||||||
|
|
||||||
#### 合并 LoRA 适配器
|
#### 合并 LoRA 适配器
|
||||||
@@ -168,15 +170,27 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
|||||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 保存 Ollama 配置文件
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
### 推理 LoRA 模型
|
### 推理 LoRA 模型
|
||||||
|
|
||||||
#### 使用命令行接口
|
#### 使用 vLLM+TP 批量推理
|
||||||
|
|
||||||
|
```
|
||||||
|
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用命令行对话框
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用浏览器界面
|
#### 使用浏览器对话框
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||||
@@ -196,6 +210,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
|||||||
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 使用 APOLLO 进行全参数训练
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### 使用 BAdam 进行全参数训练
|
#### 使用 BAdam 进行全参数训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -238,3 +258,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
|||||||
```bash
|
```bash
|
||||||
bash examples/extras/fsdp_qlora/train.sh
|
bash examples/extras/fsdp_qlora/train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 计算 BLEU 和 ROUGE 分数
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
|
||||||
|
```
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ fsdp_config:
|
|||||||
fsdp_use_orig_params: true
|
fsdp_use_orig_params: true
|
||||||
machine_rank: 0
|
machine_rank: 0
|
||||||
main_training_function: main
|
main_training_function: main
|
||||||
mixed_precision: fp16 # or bf16
|
mixed_precision: bf16 # or fp16
|
||||||
num_machines: 1 # the number of nodes
|
num_machines: 1 # the number of nodes
|
||||||
num_processes: 2 # the number of GPUs in all nodes
|
num_processes: 2 # the number of GPUs in all nodes
|
||||||
rdzv_backend: static
|
rdzv_backend: static
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
|
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -10,7 +11,7 @@ use_adam_mini: true
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: qwen
|
template: qwen
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -33,7 +34,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
45
examples/extras/apollo/llama3_full_sft.yaml
Normal file
45
examples/extras/apollo/llama3_full_sft.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
use_apollo: true
|
||||||
|
apollo_layerwise: true # choices: [true, false], use false for DDP training
|
||||||
|
apollo_target: all
|
||||||
|
apollo_rank: 128
|
||||||
|
apollo_scale: 32.0
|
||||||
|
apollo_scale_type: channel
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: identity,alpaca_en_demo
|
||||||
|
template: llama3
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/llama3-8b/full/sft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 1 # use 1 for layerwise apollo
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
pure_bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
# val_size: 0.1
|
||||||
|
# per_device_eval_batch_size: 1
|
||||||
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -15,7 +16,7 @@ badam_verbose: 2
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -36,7 +37,7 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
quantization_bit: 4
|
quantization_bit: 4
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -34,7 +36,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,20 +1,21 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
use_galore: true
|
use_galore: true
|
||||||
galore_layerwise: true
|
galore_layerwise: true # choices: [true, false], use false for DDP training
|
||||||
galore_target: mlp,self_attn
|
galore_target: all
|
||||||
galore_rank: 128
|
galore_rank: 128
|
||||||
galore_scale: 2.0
|
galore_scale: 2.0
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -28,7 +29,7 @@ overwrite_output_dir: true
|
|||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1 # use 1 for layerwise galore
|
||||||
learning_rate: 1.0e-5
|
learning_rate: 1.0e-5
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
@@ -37,7 +38,7 @@ pure_bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: models/llama3-8b-pro
|
model_name_or_path: models/llama3-8b-pro
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -12,7 +13,7 @@ use_llama_pro: true
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -35,7 +36,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
loraplus_lr_ratio: 16.0
|
loraplus_lr_ratio: 16.0
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -34,7 +36,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -10,7 +11,7 @@ mixture_of_depths: convert
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -34,7 +35,7 @@ pure_bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
# The batch generation can be SLOW using this config.
|
||||||
|
# For faster inference, we recommend to use `scripts/vllm_infer.py`.
|
||||||
|
|
||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -10,7 +14,7 @@ finetuning_type: lora
|
|||||||
### dataset
|
### dataset
|
||||||
eval_dataset: identity,alpaca_en_demo
|
eval_dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 50
|
max_samples: 50
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pissa_init: true
|
pissa_init: true
|
||||||
pissa_iter: 16
|
pissa_iter: 16
|
||||||
@@ -13,7 +15,7 @@ pissa_convert: true
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -36,7 +38,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
template: llama3
|
template: llama3
|
||||||
|
infer_backend: huggingface # choices: [huggingface, vllm]
|
||||||
|
trust_remote_code: true
|
||||||
|
|||||||
4
examples/inference/llama3_full_sft.yaml
Normal file
4
examples/inference/llama3_full_sft.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
model_name_or_path: saves/llama3-8b/full/sft
|
||||||
|
template: llama3
|
||||||
|
infer_backend: huggingface # choices: [huggingface, vllm]
|
||||||
|
trust_remote_code: true
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||||
template: llama3
|
template: llama3
|
||||||
finetuning_type: lora
|
infer_backend: huggingface # choices: [huggingface, vllm]
|
||||||
|
trust_remote_code: true
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
template: llama3
|
template: llama3
|
||||||
infer_backend: vllm
|
infer_backend: vllm
|
||||||
vllm_enforce_eager: true
|
vllm_enforce_eager: true
|
||||||
|
trust_remote_code: true
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
||||||
template: llava
|
template: llava
|
||||||
|
infer_backend: huggingface # choices: [huggingface, vllm]
|
||||||
|
trust_remote_code: true
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
template: qwen2_vl
|
template: qwen2_vl
|
||||||
|
infer_backend: huggingface # choices: [huggingface, vllm]
|
||||||
|
trust_remote_code: true
|
||||||
|
|||||||
10
examples/merge_lora/llama3_full_sft.yaml
Normal file
10
examples/merge_lora/llama3_full_sft.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: saves/llama3-8b/full/sft
|
||||||
|
template: llama3
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### export
|
||||||
|
export_dir: output/llama3_full_sft
|
||||||
|
export_size: 5
|
||||||
|
export_device: cpu
|
||||||
|
export_legacy_format: false
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
template: llama3
|
template: llama3
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### export
|
### export
|
||||||
export_dir: models/llama3_gptq
|
export_dir: output/llama3_gptq
|
||||||
export_quantization_bit: 4
|
export_quantization_bit: 4
|
||||||
export_quantization_dataset: data/c4_demo.json
|
export_quantization_dataset: data/c4_demo.json
|
||||||
export_size: 2
|
export_size: 5
|
||||||
export_device: cpu
|
export_device: cpu
|
||||||
export_legacy_format: false
|
export_legacy_format: false
|
||||||
|
|||||||
@@ -4,10 +4,10 @@
|
|||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||||
template: llama3
|
template: llama3
|
||||||
finetuning_type: lora
|
trust_remote_code: true
|
||||||
|
|
||||||
### export
|
### export
|
||||||
export_dir: models/llama3_lora_sft
|
export_dir: output/llama3_lora_sft
|
||||||
export_size: 2
|
export_size: 5
|
||||||
export_device: cpu
|
export_device: cpu
|
||||||
export_legacy_format: false
|
export_legacy_format: false
|
||||||
|
|||||||
@@ -4,10 +4,10 @@
|
|||||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
|
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
|
||||||
template: qwen2_vl
|
template: qwen2_vl
|
||||||
finetuning_type: lora
|
trust_remote_code: true
|
||||||
|
|
||||||
### export
|
### export
|
||||||
export_dir: models/qwen2_vl_lora_sft
|
export_dir: output/qwen2_vl_lora_sft
|
||||||
export_size: 2
|
export_size: 5
|
||||||
export_device: cpu
|
export_device: cpu
|
||||||
export_legacy_format: false
|
export_legacy_format: false
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
### model
|
|
||||||
model_name_or_path: saves/llama3-8b/full/sft
|
|
||||||
|
|
||||||
### method
|
|
||||||
stage: sft
|
|
||||||
do_predict: true
|
|
||||||
finetuning_type: full
|
|
||||||
|
|
||||||
### dataset
|
|
||||||
eval_dataset: identity,alpaca_en_demo
|
|
||||||
template: llama3
|
|
||||||
cutoff_len: 1024
|
|
||||||
max_samples: 50
|
|
||||||
overwrite_cache: true
|
|
||||||
preprocessing_num_workers: 16
|
|
||||||
|
|
||||||
### output
|
|
||||||
output_dir: saves/llama3-8b/full/predict
|
|
||||||
overwrite_output_dir: true
|
|
||||||
|
|
||||||
### eval
|
|
||||||
per_device_eval_batch_size: 1
|
|
||||||
predict_with_generate: true
|
|
||||||
@@ -1,19 +1,21 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/full/sft
|
output_dir: saves/llama3-8b/full/sft
|
||||||
@@ -21,6 +23,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -31,9 +34,11 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# eval_dataset: alpaca_en_demo
|
||||||
per_device_eval_batch_size: 1
|
# val_size: 0.1
|
||||||
eval_strategy: steps
|
# per_device_eval_batch_size: 1
|
||||||
eval_steps: 500
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
@@ -1,19 +1,26 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
image_max_pixels: 262144
|
||||||
|
video_max_pixels: 16384
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
freeze_vision_tower: true # choices: [true, false]
|
||||||
|
freeze_multi_modal_projector: true # choices: [true, false]
|
||||||
|
freeze_language_model: false # choices: [true, false]
|
||||||
|
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: mllm_demo,identity
|
dataset: mllm_demo,identity,alpaca_en_demo
|
||||||
template: qwen2_vl
|
template: qwen2_vl
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/qwen2_vl-7b/full/sft
|
output_dir: saves/qwen2_vl-7b/full/sft
|
||||||
@@ -21,6 +28,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -31,9 +39,10 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: dpo
|
stage: dpo
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pref_beta: 0.1
|
pref_beta: 0.1
|
||||||
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||||
@@ -12,10 +14,11 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: dpo_en_demo
|
dataset: dpo_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/dpo
|
output_dir: saves/llama3-8b/lora/dpo
|
||||||
@@ -23,6 +26,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -33,9 +37,11 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# eval_dataset: dpo_en_demo
|
||||||
per_device_eval_batch_size: 1
|
# val_size: 0.1
|
||||||
eval_strategy: steps
|
# per_device_eval_batch_size: 1
|
||||||
eval_steps: 500
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: kto
|
stage: kto
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pref_beta: 0.1
|
pref_beta: 0.1
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: kto_en_demo
|
dataset: kto_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -34,7 +36,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
reward_model: saves/llama3-8b/lora/reward
|
reward_model: saves/llama3-8b/lora/reward
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: ppo
|
stage: ppo
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|||||||
@@ -1,18 +1,21 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: pt
|
stage: pt
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: c4_demo
|
dataset: c4_demo
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/pretrain
|
output_dir: saves/llama3-8b/lora/pretrain
|
||||||
@@ -20,6 +23,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -30,9 +34,11 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# eval_dataset: c4_demo
|
||||||
per_device_eval_batch_size: 1
|
# val_size: 0.1
|
||||||
eval_strategy: steps
|
# per_device_eval_batch_size: 1
|
||||||
eval_steps: 500
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: rm
|
stage: rm
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: dpo_en_demo
|
dataset: dpo_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/reward
|
output_dir: saves/llama3-8b/lora/reward
|
||||||
@@ -21,6 +24,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -31,9 +35,11 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# eval_dataset: dpo_en_demo
|
||||||
per_device_eval_batch_size: 1
|
# val_size: 0.1
|
||||||
eval_strategy: steps
|
# per_device_eval_batch_size: 1
|
||||||
eval_steps: 500
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/sft
|
output_dir: saves/llama3-8b/lora/sft
|
||||||
@@ -21,6 +24,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -31,9 +35,11 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# eval_dataset: alpaca_en_demo
|
||||||
per_device_eval_batch_size: 1
|
# val_size: 0.1
|
||||||
eval_strategy: steps
|
# per_device_eval_batch_size: 1
|
||||||
eval_steps: 500
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,20 +1,23 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/sft
|
output_dir: saves/llama3-8b/lora/sft
|
||||||
@@ -22,6 +25,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -32,9 +36,11 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# eval_dataset: alpaca_en_demo
|
||||||
per_device_eval_batch_size: 1
|
# val_size: 0.1
|
||||||
eval_strategy: steps
|
# per_device_eval_batch_size: 1
|
||||||
eval_steps: 500
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
|
|||||||
54
examples/train_lora/llama3_lora_sft_ray.yaml
Normal file
54
examples/train_lora/llama3_lora_sft_ray.yaml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
|
lora_target: all
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: identity,alpaca_en_demo
|
||||||
|
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
|
||||||
|
template: llama3
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: tmp_dir
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
|
### ray
|
||||||
|
ray_run_name: llama3_8b_sft_lora
|
||||||
|
ray_storage_path: ./saves
|
||||||
|
ray_num_workers: 4 # number of GPUs to use
|
||||||
|
resources_per_worker:
|
||||||
|
GPU: 1
|
||||||
|
placement_strategy: PACK
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
|
### eval
|
||||||
|
# eval_dataset: alpaca_en_demo
|
||||||
|
# val_size: 0.1
|
||||||
|
# per_device_eval_batch_size: 1
|
||||||
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
@@ -1,16 +1,18 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: mllm_demo
|
dataset: mllm_demo
|
||||||
template: llava
|
template: llava
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llava1_5-7b/lora/sft
|
output_dir: saves/llava1_5-7b/lora/sft
|
||||||
@@ -21,6 +24,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -31,9 +35,10 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
image_max_pixels: 262144
|
||||||
|
video_max_pixels: 16384
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: dpo
|
stage: dpo
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pref_beta: 0.1
|
pref_beta: 0.1
|
||||||
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||||
@@ -12,10 +16,11 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: rlhf_v
|
dataset: rlhf_v
|
||||||
template: qwen2_vl
|
template: qwen2_vl
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/qwen2_vl-7b/lora/dpo
|
output_dir: saves/qwen2_vl-7b/lora/dpo
|
||||||
@@ -23,6 +28,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -33,9 +39,10 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,19 +1,24 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
image_max_pixels: 262144
|
||||||
|
video_max_pixels: 16384
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: mllm_demo,identity # video: mllm_video_demo
|
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
|
||||||
template: qwen2_vl
|
template: qwen2_vl
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/qwen2_vl-7b/lora/sft
|
output_dir: saves/qwen2_vl-7b/lora/sft
|
||||||
@@ -21,6 +26,7 @@ logging_steps: 10
|
|||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
overwrite_output_dir: true
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
@@ -31,9 +37,10 @@ lr_scheduler_type: cosine
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
bf16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
|
model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -33,7 +35,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
|
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -33,7 +35,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
quantization_bit: 4
|
||||||
|
quantization_method: bitsandbytes
|
||||||
|
double_quantization: false
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -25,7 +29,7 @@ overwrite_output_dir: true
|
|||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 8
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
@@ -34,7 +38,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
@@ -1,16 +1,18 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
|
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -33,7 +35,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -2,17 +2,19 @@
|
|||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
quantization_bit: 4
|
quantization_bit: 4
|
||||||
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
|
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -35,7 +37,7 @@ bf16: true
|
|||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
# val_size: 0.1
|
||||||
per_device_eval_batch_size: 1
|
# per_device_eval_batch_size: 1
|
||||||
eval_strategy: steps
|
# eval_strategy: steps
|
||||||
eval_steps: 500
|
# eval_steps: 500
|
||||||
|
|||||||
@@ -2,6 +2,22 @@
|
|||||||
requires = ["setuptools>=61.0"]
|
requires = ["setuptools>=61.0"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "llamafactory"
|
||||||
|
dynamic = [
|
||||||
|
"version",
|
||||||
|
"dependencies",
|
||||||
|
"optional-dependencies",
|
||||||
|
"requires-python",
|
||||||
|
"scripts",
|
||||||
|
"authors",
|
||||||
|
"description",
|
||||||
|
"readme",
|
||||||
|
"license",
|
||||||
|
"keywords",
|
||||||
|
"classifiers"
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py38"
|
target-version = "py38"
|
||||||
line-length = 119
|
line-length = 119
|
||||||
@@ -31,3 +47,19 @@ indent-style = "space"
|
|||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
skip-magic-trailing-comma = false
|
skip-magic-trailing-comma = false
|
||||||
line-ending = "auto"
|
line-ending = "auto"
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
conflicts = [
|
||||||
|
[
|
||||||
|
{ extra = "torch-npu" },
|
||||||
|
{ extra = "aqlm" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ extra = "torch-npu" },
|
||||||
|
{ extra = "liger-kernel" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ extra = "torch-npu" },
|
||||||
|
{ extra = "vllm" },
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
transformers>=4.41.2,<=4.45.0
|
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
|
||||||
datasets>=2.16.0,<=2.21.0
|
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
|
||||||
accelerate>=0.30.1,<=0.33.0
|
datasets>=2.16.0,<=3.2.0
|
||||||
|
accelerate>=0.34.0,<=1.2.1
|
||||||
peft>=0.11.1,<=0.12.0
|
peft>=0.11.1,<=0.12.0
|
||||||
trl>=0.8.6,<=0.9.6
|
trl>=0.8.6,<=0.9.6
|
||||||
gradio>=4.0.0
|
tokenizers>=0.19.0,<=0.21.0
|
||||||
|
gradio>=4.38.0,<=5.21.0
|
||||||
pandas>=2.0.0
|
pandas>=2.0.0
|
||||||
scipy
|
scipy
|
||||||
einops
|
einops
|
||||||
@@ -19,3 +21,6 @@ fire
|
|||||||
packaging
|
packaging
|
||||||
pyyaml
|
pyyaml
|
||||||
numpy<2.0.0
|
numpy<2.0.0
|
||||||
|
av
|
||||||
|
librosa
|
||||||
|
tyro<0.9.0
|
||||||
|
|||||||
65
scripts/api_example/test_image.py
Normal file
65
scripts/api_example/test_image.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
|
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="{}".format(os.environ.get("API_KEY", "0")),
|
||||||
|
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
|
||||||
|
)
|
||||||
|
messages = []
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Output the color and number of each box."},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = client.chat.completions.create(messages=messages, model="test")
|
||||||
|
messages.append(result.choices[0].message)
|
||||||
|
print("Round 1:", result.choices[0].message.content)
|
||||||
|
# The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What kind of flower is this?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = client.chat.completions.create(messages=messages, model="test")
|
||||||
|
messages.append(result.choices[0].message)
|
||||||
|
print("Round 2:", result.choices[0].message.content)
|
||||||
|
# The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 the LlamaFactory team.
|
||||||
# Copyright 2024 the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 the LlamaFactory team.
|
||||||
# Copyright 2024 the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -20,15 +19,10 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers.modeling_utils import (
|
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
WEIGHTS_INDEX_NAME,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
shard_checkpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
@@ -41,38 +35,46 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
|||||||
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
||||||
baichuan2_state_dict.update(shard_weight)
|
baichuan2_state_dict.update(shard_weight)
|
||||||
|
|
||||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
|
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
|
||||||
if "W_pack" in key:
|
if "W_pack" in key:
|
||||||
proj_size = value.size(0) // 3
|
proj_size = value.size(0) // 3
|
||||||
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
llama_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
||||||
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
|
llama_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
|
||||||
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
|
llama_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
|
||||||
elif "lm_head" in key:
|
elif "lm_head" in key:
|
||||||
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
llama_state_dict[key] = torch.nn.functional.normalize(value)
|
||||||
else:
|
else:
|
||||||
llama2_state_dict[key] = value
|
llama_state_dict[key] = value
|
||||||
|
|
||||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||||
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
llama_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
|
||||||
|
)
|
||||||
|
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
|
||||||
|
shard = {tensor: llama_state_dict[tensor].contiguous() for tensor in tensors}
|
||||||
if save_safetensors:
|
if save_safetensors:
|
||||||
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
else:
|
else:
|
||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if not state_dict_split.is_sharded:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
|
||||||
else:
|
else:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
|
||||||
|
print(f"Model weights saved in {output_dir}.")
|
||||||
|
|
||||||
|
|
||||||
def save_config(input_dir: str, output_dir: str):
|
def save_config(input_dir: str, output_dir: str):
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||||
@@ -82,7 +84,8 @@ def save_config(input_dir: str, output_dir: str):
|
|||||||
|
|
||||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||||
json.dump(llama2_config_dict, f, indent=2)
|
json.dump(llama2_config_dict, f, indent=2)
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
|
||||||
|
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||||
|
|
||||||
|
|
||||||
def llamafy_baichuan2(
|
def llamafy_baichuan2(
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 the LlamaFactory team.
|
||||||
# Copyright 2024 the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -20,16 +19,11 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers.modeling_utils import (
|
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
WEIGHTS_INDEX_NAME,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
shard_checkpoint,
|
|
||||||
)
|
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
|
|
||||||
|
|
||||||
@@ -50,66 +44,74 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
|||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
qwen_state_dict[key] = f.get_tensor(key)
|
qwen_state_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
torch_dtype = None
|
torch_dtype = None
|
||||||
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
|
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
|
||||||
if torch_dtype is None:
|
if torch_dtype is None:
|
||||||
torch_dtype = value.dtype
|
torch_dtype = value.dtype
|
||||||
if "wte" in key:
|
if "wte" in key:
|
||||||
llama2_state_dict["model.embed_tokens.weight"] = value
|
llama_state_dict["model.embed_tokens.weight"] = value
|
||||||
elif "ln_f" in key:
|
elif "ln_f" in key:
|
||||||
llama2_state_dict["model.norm.weight"] = value
|
llama_state_dict["model.norm.weight"] = value
|
||||||
else:
|
else:
|
||||||
key = key.replace("transformer.h", "model.layers")
|
key = key.replace("transformer.h", "model.layers")
|
||||||
if "attn.c_attn" in key:
|
if "attn.c_attn" in key:
|
||||||
proj_size = value.size(0) // 3
|
proj_size = value.size(0) // 3
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
llama_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
|
llama_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
|
||||||
proj_size : 2 * proj_size, ...
|
proj_size : 2 * proj_size, ...
|
||||||
]
|
]
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
|
llama_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
|
||||||
elif "attn.c_proj" in key:
|
elif "attn.c_proj" in key:
|
||||||
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
llama_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
||||||
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
|
llama_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
|
||||||
value[:, 0]
|
value[:, 0]
|
||||||
).squeeze()
|
).squeeze()
|
||||||
elif "ln_1" in key:
|
elif "ln_1" in key:
|
||||||
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
llama_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
||||||
elif "ln_2" in key:
|
elif "ln_2" in key:
|
||||||
llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
|
llama_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
|
||||||
elif "mlp.w1" in key:
|
elif "mlp.w1" in key:
|
||||||
llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
|
llama_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
|
||||||
elif "mlp.w2" in key:
|
elif "mlp.w2" in key:
|
||||||
llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
|
llama_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
|
||||||
elif "mlp.c_proj" in key:
|
elif "mlp.c_proj" in key:
|
||||||
llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
|
llama_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
|
||||||
elif "lm_head" in key:
|
elif "lm_head" in key:
|
||||||
llama2_state_dict[key] = value
|
llama_state_dict[key] = value
|
||||||
else:
|
else:
|
||||||
raise KeyError("Unable to process key {}".format(key))
|
raise KeyError(f"Unable to process key {key}")
|
||||||
|
|
||||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||||
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
llama_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
|
||||||
|
)
|
||||||
|
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
|
||||||
|
shard = {tensor: llama_state_dict[tensor].contiguous() for tensor in tensors}
|
||||||
if save_safetensors:
|
if save_safetensors:
|
||||||
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
else:
|
else:
|
||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if not state_dict_split.is_sharded:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
|
||||||
else:
|
else:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
|
||||||
|
print(f"Model weights saved in {output_dir}.")
|
||||||
|
|
||||||
return str(torch_dtype).replace("torch.", "")
|
return str(torch_dtype).replace("torch.", "")
|
||||||
|
|
||||||
|
|
||||||
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||||
@@ -135,7 +137,8 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
|||||||
|
|
||||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||||
json.dump(llama2_config_dict, f, indent=2)
|
json.dump(llama2_config_dict, f, indent=2)
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
|
||||||
|
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||||
|
|
||||||
|
|
||||||
def llamafy_qwen(
|
def llamafy_qwen(
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 Tencent Inc. and the LlamaFactory team.
|
||||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# This code is inspired by the Tencent's LLaMA-Pro library.
|
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||||
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||||
@@ -19,84 +18,75 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
|
||||||
from transformers.modeling_utils import (
|
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
WEIGHTS_INDEX_NAME,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
shard_checkpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
def change_name(name: str, old_index: int, new_index: int) -> str:
|
def change_name(name: str, old_index: int, new_index: int) -> str:
|
||||||
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
|
return name.replace(f".{old_index:d}.", f".{new_index:d}.")
|
||||||
|
|
||||||
|
|
||||||
def block_expansion(
|
def block_expansion(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
num_expand: int,
|
num_expand: int,
|
||||||
shard_size: str = "2GB",
|
shard_size: str = "5GB",
|
||||||
save_safetensors: bool = True,
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
|
||||||
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||||
"""
|
"""
|
||||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
num_layers = getattr(config, "num_hidden_layers")
|
num_layers = getattr(config, "num_hidden_layers")
|
||||||
|
if num_layers % num_expand != 0:
|
||||||
|
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
|
||||||
|
|
||||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||||
config.save_pretrained(output_dir)
|
config.save_pretrained(output_dir)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) # load the original one
|
print(f"Expanding model of {num_layers} layers to {num_layers + num_expand} layers.")
|
||||||
if save_safetensors:
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights
|
model_name_or_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True
|
||||||
|
|
||||||
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name_or_path,
|
|
||||||
config=config,
|
|
||||||
torch_dtype="auto",
|
|
||||||
trust_remote_code=True,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
)
|
)
|
||||||
state_dict = model.state_dict()
|
assert isinstance(model, PreTrainedModel) # type hint
|
||||||
|
if save_safetensors and getattr(model.config, "tie_word_embeddings", False):
|
||||||
if num_layers % num_expand != 0:
|
del model.lm_head # safetensors does not allow shared weights
|
||||||
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
|
|
||||||
|
|
||||||
split = num_layers // num_expand
|
split = num_layers // num_expand
|
||||||
layer_cnt = 0
|
layer_cnt = 0
|
||||||
output_state_dict = OrderedDict()
|
state_dict = model.state_dict()
|
||||||
|
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
if ".{:d}.".format(i) in key:
|
if f".{i:d}." in key:
|
||||||
output_state_dict[change_name(key, i, layer_cnt)] = value
|
output_state_dict[change_name(key, i, layer_cnt)] = value
|
||||||
|
|
||||||
print("Add layer {} copied from layer {}".format(layer_cnt, i))
|
print(f"Add layer {layer_cnt} copied from layer {i}.")
|
||||||
layer_cnt += 1
|
layer_cnt += 1
|
||||||
if (i + 1) % split == 0:
|
if (i + 1) % split == 0:
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
if ".{:d}.".format(i) in key:
|
if f".{i:d}." in key:
|
||||||
if "down_proj" in key or "o_proj" in key:
|
if "down_proj" in key or "o_proj" in key:
|
||||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
||||||
else:
|
else:
|
||||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
||||||
|
|
||||||
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
|
print(f"Add layer {layer_cnt} expanded from layer {i}.")
|
||||||
layer_cnt += 1
|
layer_cnt += 1
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
@@ -104,26 +94,34 @@ def block_expansion(
|
|||||||
output_state_dict[key] = value
|
output_state_dict[key] = value
|
||||||
|
|
||||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||||
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
output_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
|
||||||
|
)
|
||||||
|
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
|
||||||
|
shard = {tensor: output_state_dict[tensor].contiguous() for tensor in tensors}
|
||||||
if save_safetensors:
|
if save_safetensors:
|
||||||
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
else:
|
else:
|
||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if not state_dict_split.is_sharded:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
|
||||||
else:
|
else:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
|
||||||
|
print(f"Model weights saved in {output_dir}.")
|
||||||
|
|
||||||
print("- Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print(f"model_name_or_path: {output_dir}")
|
||||||
print("finetuning_type: freeze")
|
print("finetuning_type: freeze")
|
||||||
print("freeze_trainable_layers: {}".format(num_expand))
|
print(f"freeze_trainable_layers: {num_expand}")
|
||||||
print("use_llama_pro: true")
|
print("use_llama_pro: true")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# This code is based on the HuggingFace's PEFT library.
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
|
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
|
||||||
@@ -70,19 +69,19 @@ def quantize_loftq(
|
|||||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||||
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||||
print("Adapter weights saved in {}".format(loftq_dir))
|
print(f"Adapter weights saved in {loftq_dir}")
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
base_model: "PreTrainedModel" = peft_model.unload()
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
print("- Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print(f"model_name_or_path: {output_dir}")
|
||||||
print("adapter_name_or_path: {}".format(loftq_dir))
|
print(f"adapter_name_or_path: {loftq_dir}")
|
||||||
print("finetuning_type: lora")
|
print("finetuning_type: lora")
|
||||||
print("quantization_bit: {}".format(loftq_bits))
|
print(f"quantization_bit: {loftq_bits}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# This code is based on the HuggingFace's PEFT library.
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
|
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
|
||||||
@@ -54,7 +53,7 @@ def quantize_pissa(
|
|||||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
target_modules=lora_target,
|
target_modules=lora_target,
|
||||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init PiSSA model
|
# Init PiSSA model
|
||||||
@@ -65,17 +64,17 @@ def quantize_pissa(
|
|||||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
||||||
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
||||||
print("Adapter weights saved in {}".format(pissa_dir))
|
print(f"Adapter weights saved in {pissa_dir}")
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
base_model: "PreTrainedModel" = peft_model.unload()
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
print("- Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print(f"model_name_or_path: {output_dir}")
|
||||||
print("adapter_name_or_path: {}".format(pissa_dir))
|
print(f"adapter_name_or_path: {pissa_dir}")
|
||||||
print("finetuning_type: lora")
|
print("finetuning_type: lora")
|
||||||
print("pissa_init: false")
|
print("pissa_init: false")
|
||||||
print("pissa_convert: true")
|
print("pissa_convert: true")
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 Microsoft Corporation and the LlamaFactory team.
|
||||||
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# This code is inspired by the Microsoft's DeepSpeed library.
|
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||||
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 imoneoi and the LlamaFactory team.
|
||||||
# Copyright 2024 imoneoi and the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# This code is inspired by the imoneoi's OpenChat library.
|
# This code is inspired by the imoneoi's OpenChat library.
|
||||||
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
||||||
@@ -23,9 +22,9 @@ import fire
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
from transformers import DataCollatorForLanguageModeling
|
||||||
|
|
||||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_train_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
@@ -42,7 +41,7 @@ def calculate_lr(
|
|||||||
dataset: str = "alpaca_en_demo",
|
dataset: str = "alpaca_en_demo",
|
||||||
dataset_dir: str = "data",
|
dataset_dir: str = "data",
|
||||||
template: str = "default",
|
template: str = "default",
|
||||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
cutoff_len: int = 2048, # i.e. maximum input length during training
|
||||||
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
|
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
|
||||||
packing: bool = False,
|
packing: bool = False,
|
||||||
):
|
):
|
||||||
@@ -60,6 +59,7 @@ def calculate_lr(
|
|||||||
template=template,
|
template=template,
|
||||||
cutoff_len=cutoff_len,
|
cutoff_len=cutoff_len,
|
||||||
packing=packing,
|
packing=packing,
|
||||||
|
preprocessing_num_workers=16,
|
||||||
output_dir="dummy_dir",
|
output_dir="dummy_dir",
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
@@ -72,24 +72,25 @@ def calculate_lr(
|
|||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
elif stage == "sft":
|
elif stage == "sft":
|
||||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||||
|
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||||
|
|
||||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||||
valid_tokens, total_tokens = 0, 0
|
valid_tokens, total_tokens = 0, 0
|
||||||
for batch in tqdm(dataloader):
|
for batch in tqdm(dataloader, desc="Collecting valid tokens"):
|
||||||
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||||
total_tokens += torch.numel(batch["labels"])
|
total_tokens += torch.numel(batch["labels"])
|
||||||
|
|
||||||
batch_max_len = cutoff_len * batch_size # max tokens in a batch
|
|
||||||
valid_ratio = valid_tokens / total_tokens
|
valid_ratio = valid_tokens / total_tokens
|
||||||
batch_valid_len = batch_max_len * valid_ratio
|
token_batch_size = cutoff_len * batch_size * valid_ratio
|
||||||
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
|
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
|
||||||
lr = lr / 6.0 if is_mistral_or_gemma else lr
|
lr = lr / 6.0 if is_mistral_or_gemma else lr
|
||||||
print(
|
print(
|
||||||
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
|
||||||
lr, valid_ratio * 100, batch_valid_len
|
lr, valid_ratio * 100, token_batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 the LlamaFactory team.
|
||||||
# Copyright 2024 the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
|
|||||||
elif "4090" in device_name:
|
elif "4090" in device_name:
|
||||||
return 98 * 1e12 * world_size
|
return 98 * 1e12 * world_size
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Device not supported: {}.".format(device_name))
|
raise NotImplementedError(f"Device not supported: {device_name}.")
|
||||||
|
|
||||||
|
|
||||||
def calculate_mfu(
|
def calculate_mfu(
|
||||||
@@ -140,24 +139,26 @@ def calculate_mfu(
|
|||||||
"bf16": True,
|
"bf16": True,
|
||||||
}
|
}
|
||||||
if deepspeed_stage in [2, 3]:
|
if deepspeed_stage in [2, 3]:
|
||||||
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
|
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
|
||||||
|
|
||||||
run_exp(args)
|
run_exp(args)
|
||||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
|
|
||||||
result = json.load(f)
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
else:
|
else:
|
||||||
world_size = 1
|
world_size = 1
|
||||||
|
|
||||||
total_batch_size = batch_size * world_size
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
mfu_value = (
|
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
|
||||||
result["train_steps_per_second"]
|
result = json.load(f)
|
||||||
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
|
||||||
/ compute_device_flops(world_size)
|
total_batch_size = batch_size * world_size
|
||||||
)
|
mfu_value = (
|
||||||
print("MFU: {:.2f}%".format(mfu_value * 100))
|
result["train_steps_per_second"]
|
||||||
|
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||||
|
/ compute_device_flops(world_size)
|
||||||
|
)
|
||||||
|
print(f"MFU: {mfu_value * 100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 the LlamaFactory team.
|
||||||
# Copyright 2024 the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -21,16 +20,16 @@ import fire
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
from transformers import DataCollatorForLanguageModeling
|
||||||
|
|
||||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_train_args
|
||||||
from llamafactory.model import load_model, load_tokenizer
|
from llamafactory.model import load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for pairwise data.
|
Data collator for pairwise data.
|
||||||
"""
|
"""
|
||||||
@@ -40,36 +39,39 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
Pads batched data to the longest sequence in the batch.
|
Pads batched data to the longest sequence in the batch.
|
||||||
|
|
||||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
|
||||||
the last n examples represent rejected examples.
|
|
||||||
"""
|
"""
|
||||||
chosen_features = []
|
chosen_features = []
|
||||||
for feature in features:
|
for feature in features:
|
||||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
|
chosen_features.append(
|
||||||
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
|
{
|
||||||
attention_mask = [1] * (prompt_len + answer_len)
|
"input_ids": feature["chosen_input_ids"],
|
||||||
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
|
"attention_mask": feature["chosen_attention_mask"],
|
||||||
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
|
"labels": feature["chosen_input_ids"] if self.train_on_prompt else feature["chosen_labels"],
|
||||||
|
"images": feature["images"],
|
||||||
|
"videos": feature["videos"],
|
||||||
|
"audios": feature["audios"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return super().__call__(chosen_features)
|
return super().__call__(chosen_features)
|
||||||
|
|
||||||
|
|
||||||
def calculate_ppl(
|
def calculate_ppl(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
save_name: str,
|
save_name: str = "ppl.json",
|
||||||
batch_size: int = 4,
|
batch_size: int = 4,
|
||||||
stage: Literal["pt", "sft", "rm"] = "sft",
|
stage: Literal["pt", "sft", "rm"] = "sft",
|
||||||
dataset: str = "alpaca_en_demo",
|
dataset: str = "alpaca_en_demo",
|
||||||
dataset_dir: str = "data",
|
dataset_dir: str = "data",
|
||||||
template: str = "default",
|
template: str = "default",
|
||||||
cutoff_len: int = 1024,
|
cutoff_len: int = 2048,
|
||||||
max_samples: Optional[int] = None,
|
max_samples: Optional[int] = None,
|
||||||
train_on_prompt: bool = False,
|
train_on_prompt: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculates the ppl on the dataset of the pre-trained models.
|
Calculates the ppl on the dataset of the pre-trained models.
|
||||||
Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
|
Usage: export CUDA_VISIBLE_DEVICES=0
|
||||||
|
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
|
||||||
"""
|
"""
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
@@ -81,6 +83,7 @@ def calculate_ppl(
|
|||||||
cutoff_len=cutoff_len,
|
cutoff_len=cutoff_len,
|
||||||
max_samples=max_samples,
|
max_samples=max_samples,
|
||||||
train_on_prompt=train_on_prompt,
|
train_on_prompt=train_on_prompt,
|
||||||
|
preprocessing_num_workers=16,
|
||||||
output_dir="dummy_dir",
|
output_dir="dummy_dir",
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
@@ -94,13 +97,15 @@ def calculate_ppl(
|
|||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
elif stage == "sft":
|
elif stage == "sft":
|
||||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||||
|
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
|
||||||
|
)
|
||||||
elif stage == "rm":
|
elif stage == "rm":
|
||||||
data_collator = PairwiseDataCollatorWithPadding(
|
data_collator = PairwiseDataCollatorWithPadding(
|
||||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||||
|
|
||||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
@@ -108,7 +113,7 @@ def calculate_ppl(
|
|||||||
perplexities = []
|
perplexities = []
|
||||||
batch: Dict[str, "torch.Tensor"]
|
batch: Dict[str, "torch.Tensor"]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in tqdm(dataloader):
|
for batch in tqdm(dataloader, desc="Computing perplexities"):
|
||||||
batch = batch.to(model.device)
|
batch = batch.to(model.device)
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
||||||
@@ -125,8 +130,8 @@ def calculate_ppl(
|
|||||||
with open(save_name, "w", encoding="utf-8") as f:
|
with open(save_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(perplexities, f, indent=2)
|
json.dump(perplexities, f, indent=2)
|
||||||
|
|
||||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
|
||||||
print("Perplexities have been saved at {}.".format(save_name))
|
print(f"Perplexities have been saved at {save_name}.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# coding=utf-8
|
# Copyright 2025 the LlamaFactory team.
|
||||||
# Copyright 2024 the LlamaFactory team.
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -32,7 +31,8 @@ def length_cdf(
|
|||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculates the distribution of the input lengths in the dataset.
|
Calculates the distribution of the input lengths in the dataset.
|
||||||
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
|
Usage: export CUDA_VISIBLE_DEVICES=0
|
||||||
|
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
|
||||||
"""
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
@@ -42,6 +42,7 @@ def length_cdf(
|
|||||||
dataset_dir=dataset_dir,
|
dataset_dir=dataset_dir,
|
||||||
template=template,
|
template=template,
|
||||||
cutoff_len=1_000_000,
|
cutoff_len=1_000_000,
|
||||||
|
preprocessing_num_workers=16,
|
||||||
output_dir="dummy_dir",
|
output_dir="dummy_dir",
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
@@ -52,7 +53,7 @@ def length_cdf(
|
|||||||
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
|
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
|
||||||
total_num = len(trainset)
|
total_num = len(trainset)
|
||||||
length_dict = defaultdict(int)
|
length_dict = defaultdict(int)
|
||||||
for sample in tqdm(trainset["input_ids"]):
|
for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
|
||||||
length_dict[len(sample) // interval * interval] += 1
|
length_dict[len(sample) // interval * interval] += 1
|
||||||
|
|
||||||
length_tuples = list(length_dict.items())
|
length_tuples = list(length_dict.items())
|
||||||
@@ -61,7 +62,7 @@ def length_cdf(
|
|||||||
for length, count in length_tuples:
|
for length, count in length_tuples:
|
||||||
count_accu += count
|
count_accu += count
|
||||||
prob_accu += count / total_num * 100
|
prob_accu += count / total_num * 100
|
||||||
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
151
scripts/vllm_infer.py
Normal file
151
scripts/vllm_infer.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
|
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||||
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
|
from llamafactory.extras.misc import check_version, get_device_count
|
||||||
|
from llamafactory.extras.packages import is_vllm_available
|
||||||
|
from llamafactory.hparams import get_infer_args
|
||||||
|
from llamafactory.model import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if is_vllm_available():
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_infer(
|
||||||
|
model_name_or_path: str,
|
||||||
|
adapter_name_or_path: str = None,
|
||||||
|
dataset: str = "alpaca_en_demo",
|
||||||
|
dataset_dir: str = "data",
|
||||||
|
template: str = "default",
|
||||||
|
cutoff_len: int = 2048,
|
||||||
|
max_samples: Optional[int] = None,
|
||||||
|
vllm_config: str = "{}",
|
||||||
|
save_name: str = "generated_predictions.jsonl",
|
||||||
|
temperature: float = 0.95,
|
||||||
|
top_p: float = 0.7,
|
||||||
|
top_k: int = 50,
|
||||||
|
max_new_tokens: int = 1024,
|
||||||
|
repetition_penalty: float = 1.0,
|
||||||
|
skip_special_tokens: bool = True,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
pipeline_parallel_size: int = 1,
|
||||||
|
image_max_pixels: int = 768 * 768,
|
||||||
|
image_min_pixels: int = 32 * 32,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Performs batch generation using vLLM engine, which supports tensor parallelism.
|
||||||
|
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
|
||||||
|
"""
|
||||||
|
check_version("vllm>=0.4.3,<=0.7.3")
|
||||||
|
if pipeline_parallel_size > get_device_count():
|
||||||
|
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
|
||||||
|
|
||||||
|
model_args, data_args, _, generating_args = get_infer_args(
|
||||||
|
dict(
|
||||||
|
model_name_or_path=model_name_or_path,
|
||||||
|
adapter_name_or_path=adapter_name_or_path,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_dir=dataset_dir,
|
||||||
|
template=template,
|
||||||
|
cutoff_len=cutoff_len,
|
||||||
|
max_samples=max_samples,
|
||||||
|
preprocessing_num_workers=16,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
|
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
|
||||||
|
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
|
||||||
|
|
||||||
|
inputs, prompts, labels = [], [], []
|
||||||
|
for sample in dataset_module["train_dataset"]:
|
||||||
|
if sample["images"]:
|
||||||
|
multi_modal_data = {
|
||||||
|
"image": template_obj.mm_plugin._regularize_images(
|
||||||
|
sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||||
|
)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
multi_modal_data = None
|
||||||
|
|
||||||
|
inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
|
||||||
|
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
|
||||||
|
labels.append(
|
||||||
|
tokenizer.decode(
|
||||||
|
list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
|
||||||
|
temperature=generating_args.temperature,
|
||||||
|
top_p=generating_args.top_p or 1.0, # top_p must > 0
|
||||||
|
top_k=generating_args.top_k or -1, # top_k must > 0
|
||||||
|
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
|
||||||
|
max_tokens=generating_args.max_new_tokens,
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
if model_args.adapter_name_or_path is not None:
|
||||||
|
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
|
||||||
|
else:
|
||||||
|
lora_request = None
|
||||||
|
|
||||||
|
engine_args = {
|
||||||
|
"model": model_args.model_name_or_path,
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"dtype": model_args.infer_dtype,
|
||||||
|
"max_model_len": cutoff_len + max_new_tokens,
|
||||||
|
"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
|
||||||
|
"pipeline_parallel_size": pipeline_parallel_size,
|
||||||
|
"disable_log_stats": True,
|
||||||
|
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||||
|
}
|
||||||
|
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||||
|
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
|
||||||
|
|
||||||
|
if isinstance(model_args.vllm_config, dict):
|
||||||
|
engine_args.update(model_args.vllm_config)
|
||||||
|
|
||||||
|
results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
|
||||||
|
preds = [result.outputs[0].text for result in results]
|
||||||
|
with open(save_name, "w", encoding="utf-8") as f:
|
||||||
|
for text, pred, label in zip(prompts, preds, labels):
|
||||||
|
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
print("*" * 70)
|
||||||
|
print(f"{len(prompts)} generated results have been saved at {save_name}.")
|
||||||
|
print("*" * 70)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(vllm_infer)
|
||||||
37
setup.py
37
setup.py
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -20,7 +20,7 @@ from setuptools import find_packages, setup
|
|||||||
|
|
||||||
|
|
||||||
def get_version() -> str:
|
def get_version() -> str:
|
||||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||||
(version,) = re.findall(pattern, file_content)
|
(version,) = re.findall(pattern, file_content)
|
||||||
@@ -28,7 +28,7 @@ def get_version() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_requires() -> List[str]:
|
def get_requires() -> List[str]:
|
||||||
with open("requirements.txt", "r", encoding="utf-8") as f:
|
with open("requirements.txt", encoding="utf-8") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||||
return lines
|
return lines
|
||||||
@@ -36,7 +36,7 @@ def get_requires() -> List[str]:
|
|||||||
|
|
||||||
def get_console_scripts() -> List[str]:
|
def get_console_scripts() -> List[str]:
|
||||||
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
|
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
|
||||||
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
|
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
|
||||||
console_scripts.append("lmf = llamafactory.cli:main")
|
console_scripts.append("lmf = llamafactory.cli:main")
|
||||||
|
|
||||||
return console_scripts
|
return console_scripts
|
||||||
@@ -44,9 +44,9 @@ def get_console_scripts() -> List[str]:
|
|||||||
|
|
||||||
extra_require = {
|
extra_require = {
|
||||||
"torch": ["torch>=1.13.1"],
|
"torch": ["torch>=1.13.1"],
|
||||||
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
"torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
|
||||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||||
"deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
|
"deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
|
||||||
"liger-kernel": ["liger-kernel"],
|
"liger-kernel": ["liger-kernel"],
|
||||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||||
"hqq": ["hqq"],
|
"hqq": ["hqq"],
|
||||||
@@ -54,13 +54,26 @@ extra_require = {
|
|||||||
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
||||||
"awq": ["autoawq"],
|
"awq": ["autoawq"],
|
||||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||||
"vllm": ["vllm>=0.4.3,<=0.6.0"],
|
"vllm": ["vllm>=0.4.3,<=0.7.3"],
|
||||||
"galore": ["galore-torch"],
|
"galore": ["galore-torch"],
|
||||||
|
"apollo": ["apollo-torch"],
|
||||||
"badam": ["badam>=1.2.1"],
|
"badam": ["badam>=1.2.1"],
|
||||||
"adam-mini": ["adam-mini"],
|
"adam-mini": ["adam-mini"],
|
||||||
"qwen": ["transformers_stream_generator"],
|
"qwen": ["transformers_stream_generator"],
|
||||||
|
"minicpm_v": [
|
||||||
|
"soundfile",
|
||||||
|
"torchvision",
|
||||||
|
"torchaudio",
|
||||||
|
"vector_quantize_pytorch",
|
||||||
|
"vocos",
|
||||||
|
"msgpack",
|
||||||
|
"referencing",
|
||||||
|
"jsonschema_specifications",
|
||||||
|
],
|
||||||
"modelscope": ["modelscope"],
|
"modelscope": ["modelscope"],
|
||||||
"dev": ["ruff", "pytest"],
|
"openmind": ["openmind"],
|
||||||
|
"swanlab": ["swanlab"],
|
||||||
|
"dev": ["pre-commit", "ruff", "pytest"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -69,16 +82,16 @@ def main():
|
|||||||
name="llamafactory",
|
name="llamafactory",
|
||||||
version=get_version(),
|
version=get_version(),
|
||||||
author="hiyouga",
|
author="hiyouga",
|
||||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
author_email="hiyouga AT buaa.edu.cn",
|
||||||
description="Easy-to-use LLM fine-tuning framework",
|
description="Easy-to-use LLM fine-tuning framework",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||||
license="Apache 2.0 License",
|
license="Apache 2.0 License",
|
||||||
url="https://github.com/hiyouga/LLaMA-Factory",
|
url="https://github.com/hiyouga/LLaMA-Factory",
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
packages=find_packages("src"),
|
packages=find_packages("src"),
|
||||||
python_requires=">=3.8.0",
|
python_requires=">=3.9.0",
|
||||||
install_requires=get_requires(),
|
install_requires=get_requires(),
|
||||||
extras_require=extra_require,
|
extras_require=extra_require,
|
||||||
entry_points={"console_scripts": get_console_scripts()},
|
entry_points={"console_scripts": get_console_scripts()},
|
||||||
@@ -90,10 +103,10 @@ def main():
|
|||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
|
|||||||
def main():
|
def main():
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
app = create_app(chat_model)
|
app = create_app(chat_model)
|
||||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
api_host = os.getenv("API_HOST", "0.0.0.0")
|
||||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
api_port = int(os.getenv("API_PORT", "8000"))
|
||||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||||
uvicorn.run(app, host=api_host, port=api_port)
|
uvicorn.run(app, host=api_host, port=api_port)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -20,17 +20,17 @@ Level:
|
|||||||
|
|
||||||
Dependency graph:
|
Dependency graph:
|
||||||
main:
|
main:
|
||||||
transformers>=4.41.2,<=4.45.0
|
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||||
datasets>=2.16.0,<=2.21.0
|
datasets>=2.16.0,<=3.2.0
|
||||||
accelerate>=0.30.1,<=0.33.0
|
accelerate>=0.34.0,<=1.2.1
|
||||||
peft>=0.11.1,<=0.12.0
|
peft>=0.11.1,<=0.12.0
|
||||||
trl>=0.8.6,<=0.9.6
|
trl>=0.8.6,<=0.9.6
|
||||||
attention:
|
attention:
|
||||||
transformers>=4.42.4 (gemma+fa2)
|
transformers>=4.42.4 (gemma+fa2)
|
||||||
longlora:
|
longlora:
|
||||||
transformers>=4.41.2,<=4.45.0
|
transformers>=4.41.2,<4.48.0
|
||||||
packing:
|
packing:
|
||||||
transformers>=4.41.2,<=4.45.0
|
transformers>=4.43.0
|
||||||
|
|
||||||
Disable version checking: DISABLE_VERSION_CHECK=1
|
Disable version checking: DISABLE_VERSION_CHECK=1
|
||||||
Enable VRAM recording: RECORD_VRAM=1
|
Enable VRAM recording: RECORD_VRAM=1
|
||||||
@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
|
|||||||
Force using torchrun: FORCE_TORCHRUN=1
|
Force using torchrun: FORCE_TORCHRUN=1
|
||||||
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
||||||
Use modelscope: USE_MODELSCOPE_HUB=1
|
Use modelscope: USE_MODELSCOPE_HUB=1
|
||||||
|
Use openmind: USE_OPENMIND_HUB=1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .extras.env import VERSION
|
from .extras.env import VERSION
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -21,6 +21,7 @@ from typing import Optional
|
|||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from ..chat import ChatModel
|
from ..chat import ChatModel
|
||||||
|
from ..extras.constants import EngineName
|
||||||
from ..extras.misc import torch_gc
|
from ..extras.misc import torch_gc
|
||||||
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
|
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
|
||||||
from .chat import (
|
from .chat import (
|
||||||
@@ -60,7 +61,7 @@ async def sweeper() -> None:
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
|
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
|
||||||
if chat_model.engine_type == "huggingface":
|
if chat_model.engine.name == EngineName.HF:
|
||||||
asyncio.create_task(sweeper())
|
asyncio.create_task(sweeper())
|
||||||
|
|
||||||
yield
|
yield
|
||||||
@@ -68,7 +69,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
|
|||||||
|
|
||||||
|
|
||||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
|
root_path = os.getenv("FASTAPI_ROOT_PATH", "")
|
||||||
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
|
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
@@ -77,7 +78,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
api_key = os.environ.get("API_KEY", None)
|
api_key = os.getenv("API_KEY")
|
||||||
security = HTTPBearer(auto_error=False)
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||||
@@ -91,7 +92,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
dependencies=[Depends(verify_api_key)],
|
dependencies=[Depends(verify_api_key)],
|
||||||
)
|
)
|
||||||
async def list_models():
|
async def list_models():
|
||||||
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
|
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
|
||||||
return ModelList(data=[model_card])
|
return ModelList(data=[model_card])
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
generate = create_stream_chat_completion_response(request, chat_model)
|
generate = create_stream_chat_completion_response(request, chat_model)
|
||||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
|
||||||
else:
|
else:
|
||||||
return await create_chat_completion_response(request, chat_model)
|
return await create_chat_completion_response(request, chat_model)
|
||||||
|
|
||||||
@@ -128,7 +129,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
def run_api() -> None:
|
def run_api() -> None:
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
app = create_app(chat_model)
|
app = create_app(chat_model)
|
||||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
api_host = os.getenv("API_HOST", "0.0.0.0")
|
||||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
api_port = int(os.getenv("API_PORT", "8000"))
|
||||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||||
uvicorn.run(app, host=api_host, port=api_port)
|
uvicorn.run(app, host=api_host, port=api_port)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -21,7 +21,9 @@ import uuid
|
|||||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from ..data import Role as DataRole
|
from ..data import Role as DataRole
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
|
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||||
|
from ..extras.misc import is_env_enabled
|
||||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||||
from .common import dictify, jsonify
|
from .common import dictify, jsonify
|
||||||
from .protocol import (
|
from .protocol import (
|
||||||
@@ -57,7 +59,7 @@ if TYPE_CHECKING:
|
|||||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
ROLE_MAPPING = {
|
ROLE_MAPPING = {
|
||||||
Role.USER: DataRole.USER.value,
|
Role.USER: DataRole.USER.value,
|
||||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||||
@@ -69,8 +71,9 @@ ROLE_MAPPING = {
|
|||||||
|
|
||||||
def _process_request(
|
def _process_request(
|
||||||
request: "ChatCompletionRequest",
|
request: "ChatCompletionRequest",
|
||||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
|
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
|
||||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
if is_env_enabled("API_VERBOSE", "1"):
|
||||||
|
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||||
|
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||||
@@ -84,7 +87,7 @@ def _process_request(
|
|||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||||
|
|
||||||
input_messages = []
|
input_messages = []
|
||||||
image = None
|
images = []
|
||||||
for i, message in enumerate(request.messages):
|
for i, message in enumerate(request.messages):
|
||||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
@@ -99,10 +102,12 @@ def _process_request(
|
|||||||
content = json.dumps(tool_calls, ensure_ascii=False)
|
content = json.dumps(tool_calls, ensure_ascii=False)
|
||||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||||
elif isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
|
text_content = ""
|
||||||
for input_item in message.content:
|
for input_item in message.content:
|
||||||
if input_item.type == "text":
|
if input_item.type == "text":
|
||||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
text_content += input_item.text
|
||||||
else:
|
else:
|
||||||
|
text_content += IMAGE_PLACEHOLDER
|
||||||
image_url = input_item.image_url.url
|
image_url = input_item.image_url.url
|
||||||
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
|
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
|
||||||
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
|
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
|
||||||
@@ -111,7 +116,9 @@ def _process_request(
|
|||||||
else: # web uri
|
else: # web uri
|
||||||
image_stream = requests.get(image_url, stream=True).raw
|
image_stream = requests.get(image_url, stream=True).raw
|
||||||
|
|
||||||
image = Image.open(image_stream).convert("RGB")
|
images.append(Image.open(image_stream).convert("RGB"))
|
||||||
|
|
||||||
|
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
|
||||||
else:
|
else:
|
||||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||||
|
|
||||||
@@ -124,7 +131,7 @@ def _process_request(
|
|||||||
else:
|
else:
|
||||||
tools = None
|
tools = None
|
||||||
|
|
||||||
return input_messages, system, tools, image
|
return input_messages, system, tools, images or None
|
||||||
|
|
||||||
|
|
||||||
def _create_stream_chat_completion_chunk(
|
def _create_stream_chat_completion_chunk(
|
||||||
@@ -142,13 +149,13 @@ def _create_stream_chat_completion_chunk(
|
|||||||
async def create_chat_completion_response(
|
async def create_chat_completion_response(
|
||||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||||
) -> "ChatCompletionResponse":
|
) -> "ChatCompletionResponse":
|
||||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
input_messages, system, tools, image = _process_request(request)
|
input_messages, system, tools, images = _process_request(request)
|
||||||
responses = await chat_model.achat(
|
responses = await chat_model.achat(
|
||||||
input_messages,
|
input_messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
@@ -168,8 +175,8 @@ async def create_chat_completion_response(
|
|||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for tool in result:
|
for tool in result:
|
||||||
function = Function(name=tool[0], arguments=tool[1])
|
function = Function(name=tool.name, arguments=tool.arguments)
|
||||||
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
|
||||||
|
|
||||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||||
finish_reason = Finish.TOOL
|
finish_reason = Finish.TOOL
|
||||||
@@ -193,8 +200,8 @@ async def create_chat_completion_response(
|
|||||||
async def create_stream_chat_completion_response(
|
async def create_stream_chat_completion_response(
|
||||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
input_messages, system, tools, image = _process_request(request)
|
input_messages, system, tools, images = _process_request(request)
|
||||||
if tools:
|
if tools:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||||
|
|
||||||
@@ -208,7 +215,7 @@ async def create_stream_chat_completion_response(
|
|||||||
input_messages,
|
input_messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
@@ -229,8 +236,9 @@ async def create_stream_chat_completion_response(
|
|||||||
async def create_score_evaluation_response(
|
async def create_score_evaluation_response(
|
||||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||||
) -> "ScoreEvaluationResponse":
|
) -> "ScoreEvaluationResponse":
|
||||||
|
score_id = f"scoreval-{uuid.uuid4().hex}"
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||||
|
|
||||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -22,7 +22,8 @@ if TYPE_CHECKING:
|
|||||||
from vllm import AsyncLLMEngine
|
from vllm import AsyncLLMEngine
|
||||||
|
|
||||||
from ..data import Template
|
from ..data import Template
|
||||||
from ..data.mm_plugin import ImageInput, VideoInput
|
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
|
||||||
|
from ..extras.constants import EngineName
|
||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -41,6 +42,7 @@ class BaseEngine(ABC):
|
|||||||
Must implements async methods: chat(), stream_chat() and get_scores().
|
Must implements async methods: chat(), stream_chat() and get_scores().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
name: "EngineName"
|
||||||
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
||||||
tokenizer: "PreTrainedTokenizer"
|
tokenizer: "PreTrainedTokenizer"
|
||||||
can_generate: bool
|
can_generate: bool
|
||||||
@@ -66,8 +68,9 @@ class BaseEngine(ABC):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
r"""
|
r"""
|
||||||
@@ -81,8 +84,9 @@ class BaseEngine(ABC):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import os
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||||
|
|
||||||
|
from ..extras.constants import EngineName
|
||||||
from ..extras.misc import torch_gc
|
from ..extras.misc import torch_gc
|
||||||
from ..hparams import get_infer_args
|
from ..hparams import get_infer_args
|
||||||
from .hf_engine import HuggingfaceEngine
|
from .hf_engine import HuggingfaceEngine
|
||||||
@@ -27,7 +28,7 @@ from .vllm_engine import VllmEngine
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..data.mm_plugin import ImageInput, VideoInput
|
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
@@ -47,13 +48,12 @@ class ChatModel:
|
|||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||||
self.engine_type = model_args.infer_backend
|
if model_args.infer_backend == EngineName.HF:
|
||||||
if model_args.infer_backend == "huggingface":
|
|
||||||
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
elif model_args.infer_backend == "vllm":
|
elif model_args.infer_backend == EngineName.VLLM:
|
||||||
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||||
|
|
||||||
self._loop = asyncio.new_event_loop()
|
self._loop = asyncio.new_event_loop()
|
||||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
@@ -64,15 +64,16 @@ class ChatModel:
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
r"""
|
r"""
|
||||||
Gets a list of responses of the chat model.
|
Gets a list of responses of the chat model.
|
||||||
"""
|
"""
|
||||||
task = asyncio.run_coroutine_threadsafe(
|
task = asyncio.run_coroutine_threadsafe(
|
||||||
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
|
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
|
||||||
)
|
)
|
||||||
return task.result()
|
return task.result()
|
||||||
|
|
||||||
@@ -81,28 +82,30 @@ class ChatModel:
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
r"""
|
r"""
|
||||||
Asynchronously gets a list of responses of the chat model.
|
Asynchronously gets a list of responses of the chat model.
|
||||||
"""
|
"""
|
||||||
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
|
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||||
|
|
||||||
def stream_chat(
|
def stream_chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
r"""
|
r"""
|
||||||
Gets the response token-by-token of the chat model.
|
Gets the response token-by-token of the chat model.
|
||||||
"""
|
"""
|
||||||
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
|
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||||
@@ -115,14 +118,17 @@ class ChatModel:
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
r"""
|
r"""
|
||||||
Asynchronously gets the response token-by-token of the chat model.
|
Asynchronously gets the response token-by-token of the chat model.
|
||||||
"""
|
"""
|
||||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
|
async for new_token in self.engine.stream_chat(
|
||||||
|
messages, system, tools, images, videos, audios, **input_kwargs
|
||||||
|
):
|
||||||
yield new_token
|
yield new_token
|
||||||
|
|
||||||
def get_scores(
|
def get_scores(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2025 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras import logging
|
||||||
from ..extras.logging import get_logger
|
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
|
||||||
from ..extras.misc import get_logits_processor
|
from ..extras.misc import get_logits_processor
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
@@ -35,11 +35,11 @@ if TYPE_CHECKING:
|
|||||||
from trl import PreTrainedModelWrapper
|
from trl import PreTrainedModelWrapper
|
||||||
|
|
||||||
from ..data import Template
|
from ..data import Template
|
||||||
from ..data.mm_plugin import ImageInput, VideoInput
|
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
|
||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceEngine(BaseEngine):
|
class HuggingfaceEngine(BaseEngine):
|
||||||
@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.name = EngineName.HF
|
||||||
self.can_generate = finetuning_args.stage == "sft"
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
self.tokenizer = tokenizer_module["tokenizer"]
|
self.tokenizer = tokenizer_module["tokenizer"]
|
||||||
@@ -63,11 +64,11 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
try:
|
try:
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logger.warning("There is no current event loop, creating a new one.")
|
logger.warning_rank0_once("There is no current event loop, creating a new one.")
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_args(
|
def _process_args(
|
||||||
@@ -79,29 +80,41 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
|
||||||
if image is not None:
|
if images is not None:
|
||||||
mm_input_dict.update({"images": [image], "imglens": [1]})
|
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
|
|
||||||
if video is not None:
|
if videos is not None:
|
||||||
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||||
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||||
|
|
||||||
|
if audios is not None:
|
||||||
|
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
|
||||||
|
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
|
||||||
|
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
|
||||||
|
|
||||||
messages = template.mm_plugin.process_messages(
|
messages = template.mm_plugin.process_messages(
|
||||||
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
|
||||||
)
|
)
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or generating_args["default_system"]
|
system = system or generating_args["default_system"]
|
||||||
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
||||||
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
||||||
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
|
prompt_ids,
|
||||||
|
None,
|
||||||
|
mm_input_dict["images"],
|
||||||
|
mm_input_dict["videos"],
|
||||||
|
mm_input_dict["audios"],
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
)
|
)
|
||||||
prompt_length = len(prompt_ids)
|
prompt_length = len(prompt_ids)
|
||||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||||
@@ -114,12 +127,13 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
||||||
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
||||||
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
||||||
|
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
|
||||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||||
|
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||||
|
|
||||||
generating_args = generating_args.copy()
|
generating_args = generating_args.copy()
|
||||||
generating_args.update(
|
generating_args.update(
|
||||||
@@ -133,7 +147,10 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
if repetition_penalty is not None
|
if repetition_penalty is not None
|
||||||
else generating_args["repetition_penalty"],
|
else generating_args["repetition_penalty"],
|
||||||
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
||||||
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
skip_special_tokens=skip_special_tokens
|
||||||
|
if skip_special_tokens is not None
|
||||||
|
else generating_args["skip_special_tokens"],
|
||||||
|
eos_token_id=template.get_stop_token_ids(tokenizer),
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -164,10 +181,32 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
logits_processor=get_logits_processor(),
|
logits_processor=get_logits_processor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
|
||||||
for key, value in mm_inputs.items():
|
for key, value in mm_inputs.items():
|
||||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs
|
||||||
gen_kwargs[key] = value.to(model.device)
|
value = torch.stack(value) # assume they have same sizes
|
||||||
|
elif (
|
||||||
|
isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
|
||||||
|
): # for minicpmv inputs
|
||||||
|
value = torch.stack([torch.stack(v) for v in value])
|
||||||
|
elif not isinstance(value, torch.Tensor):
|
||||||
|
value = torch.tensor(value)
|
||||||
|
|
||||||
|
if torch.is_floating_point(value): # cast data dtype for paligemma
|
||||||
|
value = value.to(model.dtype)
|
||||||
|
|
||||||
|
if key == "second_per_grid_ts": # qwen2.5vl special case
|
||||||
|
gen_kwargs[key] = value.tolist()
|
||||||
|
else:
|
||||||
|
gen_kwargs[key] = value.to(model.device)
|
||||||
|
|
||||||
|
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
|
||||||
|
gen_kwargs["input_ids"] = inputs
|
||||||
|
gen_kwargs["tokenizer"] = tokenizer
|
||||||
|
if "audio_feature_lens" in mm_inputs:
|
||||||
|
gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
|
||||||
|
|
||||||
|
gen_kwargs.pop("image_sizes", None)
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
return gen_kwargs, prompt_length
|
||||||
|
|
||||||
@@ -182,16 +221,35 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
template,
|
||||||
|
generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
images,
|
||||||
|
videos,
|
||||||
|
audios,
|
||||||
|
input_kwargs,
|
||||||
)
|
)
|
||||||
generate_output = model.generate(**gen_kwargs)
|
generate_output = model.generate(**gen_kwargs)
|
||||||
|
if isinstance(generate_output, tuple):
|
||||||
|
generate_output = generate_output[1][0] # post-process the minicpm_o output
|
||||||
|
|
||||||
response_ids = generate_output[:, prompt_length:]
|
response_ids = generate_output[:, prompt_length:]
|
||||||
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
response = tokenizer.batch_decode(
|
||||||
|
response_ids,
|
||||||
|
skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
|
||||||
|
clean_up_tokenization_spaces=True,
|
||||||
|
)
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(response)):
|
for i in range(len(response)):
|
||||||
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
||||||
@@ -218,14 +276,30 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> Callable[[], str]:
|
) -> Callable[[], str]:
|
||||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
template,
|
||||||
|
generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
images,
|
||||||
|
videos,
|
||||||
|
audios,
|
||||||
|
input_kwargs,
|
||||||
|
)
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
tokenizer,
|
||||||
|
skip_prompt=True,
|
||||||
|
skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
|
||||||
)
|
)
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
||||||
gen_kwargs["streamer"] = streamer
|
gen_kwargs["streamer"] = streamer
|
||||||
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
@@ -246,29 +320,18 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> List[float]:
|
) -> List[float]:
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||||
device = getattr(model.pretrained_model, "device", "cuda")
|
device = getattr(model.pretrained_model, "device", "cuda")
|
||||||
inputs = tokenizer(
|
inputs: Dict[str, "torch.Tensor"] = tokenizer(
|
||||||
batch_input,
|
batch_input,
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
add_special_tokens=True,
|
add_special_tokens=False,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||||
input_ids: torch.Tensor = inputs["input_ids"]
|
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
|
||||||
|
|
||||||
if getattr(model.config, "model_type", None) == "chatglm":
|
|
||||||
values = torch.transpose(values, 0, 1)
|
|
||||||
|
|
||||||
scores = []
|
|
||||||
for i in range(input_ids.size(0)):
|
|
||||||
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
|
|
||||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
|
||||||
scores.append(values[i, end_index].nan_to_num().item())
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -277,8 +340,9 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@@ -294,8 +358,9 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
video,
|
videos,
|
||||||
|
audios,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
@@ -308,8 +373,9 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
|
audios: Optional[Sequence["AudioInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@@ -325,8 +391,9 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
video,
|
videos,
|
||||||
|
audios,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user