Compare commits
619 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
57354fc990 | ||
|
|
89f240805c | ||
|
|
27bbea886c | ||
|
|
3ec3dda33a | ||
|
|
ae9f338bf7 | ||
|
|
bf44f76dc7 | ||
|
|
c18581f0a4 | ||
|
|
9f6c5c4798 | ||
|
|
7bc03ac986 | ||
|
|
85d7e4f4ab | ||
|
|
bf69747f40 | ||
|
|
f1146bf7b6 | ||
|
|
9efd1fec90 | ||
|
|
3b91839a55 | ||
|
|
bc4421eeef | ||
|
|
5003820a6a | ||
|
|
cd2485f28d | ||
|
|
918a367378 | ||
|
|
3d35aeca72 | ||
|
|
53b1e5fd1d | ||
|
|
b852c895cf | ||
|
|
aaa7ed8712 | ||
|
|
205aca5b03 | ||
|
|
87b1f851f1 | ||
|
|
fca814b30d | ||
|
|
a20c2b6ecf | ||
|
|
fee94e1c54 | ||
|
|
047a596542 | ||
|
|
3d45606984 | ||
|
|
310c107d56 | ||
|
|
089e4d9e96 | ||
|
|
ae56c3cf49 | ||
|
|
0a0288a286 | ||
|
|
25da686758 | ||
|
|
e2da3cc9fa | ||
|
|
c42e5cf401 | ||
|
|
9943cd1c96 | ||
|
|
1e6f96508a | ||
|
|
d401974f69 | ||
|
|
09b2dbe859 | ||
|
|
7f8ef8c132 | ||
|
|
fcb6283a72 | ||
|
|
0027f46ccc | ||
|
|
967a27695e | ||
|
|
3ce8a326c6 | ||
|
|
91b56b7baf | ||
|
|
e2fa961302 | ||
|
|
87d6d7dc61 | ||
|
|
00019e2ca4 | ||
|
|
b104739d63 | ||
|
|
b238d1aa04 | ||
|
|
aa497d5d96 | ||
|
|
fecf04b2f4 | ||
|
|
3f157e2f6f | ||
|
|
c7c558562e | ||
|
|
c2ea5fb618 | ||
|
|
fa9c32bb8d | ||
|
|
c610deb5a2 | ||
|
|
2bb3255e74 | ||
|
|
b28b74c71e | ||
|
|
1ed921bff7 | ||
|
|
80f634cc95 | ||
|
|
a3eb5e200c | ||
|
|
2d02c0e22d | ||
|
|
093eda2ad6 | ||
|
|
dbaf621f57 | ||
|
|
ceb701c2d4 | ||
|
|
29ad3783f5 | ||
|
|
fa2386e73c | ||
|
|
e0045e8386 | ||
|
|
b94c941196 | ||
|
|
ba66ac084f | ||
|
|
83479c9ef0 | ||
|
|
df8ac15ef0 | ||
|
|
8cea5cd967 | ||
|
|
a2d7d6a518 | ||
|
|
a63e624eca | ||
|
|
8596c321ce | ||
|
|
54cd799aa0 | ||
|
|
8185eb1890 | ||
|
|
03213984ec | ||
|
|
aeeee9d4b5 | ||
|
|
c8a1fb99bf | ||
|
|
f0181a41ff | ||
|
|
f6b06d0c6f | ||
|
|
1047217f78 | ||
|
|
16a9a44849 | ||
|
|
58fb24ce41 | ||
|
|
a9afffa246 | ||
|
|
1fdd053022 | ||
|
|
0a833968a0 | ||
|
|
58b681de78 | ||
|
|
22d5fc5f4c | ||
|
|
cc0119f698 | ||
|
|
580cedebde | ||
|
|
43bd1b070c | ||
|
|
42aa9c65be | ||
|
|
b0b87fa33f | ||
|
|
22912eba1a | ||
|
|
e2748fa967 | ||
|
|
248d5daaff | ||
|
|
8f5921692e | ||
|
|
e880eb8844 | ||
|
|
dc076c4e52 | ||
|
|
8306e93ef3 | ||
|
|
6a2cd129c0 | ||
|
|
30d7f6a22e | ||
|
|
5440ebbae6 | ||
|
|
22dbe694e9 | ||
|
|
64ac6ca396 | ||
|
|
377d37fa7f | ||
|
|
55296744a8 | ||
|
|
d0889012c2 | ||
|
|
3a8b2890eb | ||
|
|
5b2284a51d | ||
|
|
4807d8a4ef | ||
|
|
c6e1313977 | ||
|
|
66819fd3ee | ||
|
|
bd85e370be | ||
|
|
cc097174cc | ||
|
|
7d135bbdb8 | ||
|
|
4845a76535 | ||
|
|
67645c0db8 | ||
|
|
f463b3f038 | ||
|
|
01defc2779 | ||
|
|
c9e77ab352 | ||
|
|
c3de160d1c | ||
|
|
3693d7b571 | ||
|
|
a63144c28f | ||
|
|
2b3b0473cd | ||
|
|
9d929897ce | ||
|
|
313a5e1494 | ||
|
|
74dd25224a | ||
|
|
c7efc7f2ed | ||
|
|
c71c78da50 | ||
|
|
f4897da009 | ||
|
|
a6951db970 | ||
|
|
9d27aaa38f | ||
|
|
3b19b6f31b | ||
|
|
5b15ca0b0b | ||
|
|
aad79127e6 | ||
|
|
c42dcab32b | ||
|
|
be519c84d9 | ||
|
|
b2dc6dc59a | ||
|
|
9df626dc18 | ||
|
|
8d4b9200a1 | ||
|
|
7806df46ba | ||
|
|
bba026a212 | ||
|
|
6e111eb29f | ||
|
|
2b69ae0eb2 | ||
|
|
13d73574ef | ||
|
|
bc264807ae | ||
|
|
f9815dd20a | ||
|
|
1f58943b32 | ||
|
|
6476507429 | ||
|
|
35862d19ec | ||
|
|
1272cb00df | ||
|
|
e9ac26db4c | ||
|
|
20ee1d2e19 | ||
|
|
cbc1dd0c88 | ||
|
|
870bbabbc4 | ||
|
|
8fd84c375e | ||
|
|
32b5364051 | ||
|
|
cf72aec098 | ||
|
|
87849d12d2 | ||
|
|
a19512436f | ||
|
|
6c89d93aea | ||
|
|
345f40a660 | ||
|
|
8b9a814653 | ||
|
|
05fabf9095 | ||
|
|
95eede911a | ||
|
|
7bc7f7d673 | ||
|
|
054fdbe186 | ||
|
|
f0f80819a0 | ||
|
|
e702678252 | ||
|
|
553579986a | ||
|
|
622cb04f27 | ||
|
|
f3ba11a432 | ||
|
|
8b1f53bca5 | ||
|
|
ac25fef80e | ||
|
|
15f819d273 | ||
|
|
f2d1c43d28 | ||
|
|
464acc7d6c | ||
|
|
a96c5da737 | ||
|
|
28d09b81c9 | ||
|
|
a769d0e3d4 | ||
|
|
1b98b5e65c | ||
|
|
3cc5408da7 | ||
|
|
689f5c4554 | ||
|
|
ab5d042cd3 | ||
|
|
4d43317aa1 | ||
|
|
ed3b0c5b40 | ||
|
|
67a97794ee | ||
|
|
2c7c93cb9b | ||
|
|
4d4fe08d14 | ||
|
|
85a919b6f7 | ||
|
|
fe2abe20fc | ||
|
|
12444720db | ||
|
|
510faf5805 | ||
|
|
722e01c8ab | ||
|
|
6050e6cff9 | ||
|
|
c8abbe4fc3 | ||
|
|
f2881c9d4a | ||
|
|
1ded3abdf1 | ||
|
|
e641f1215a | ||
|
|
ca736bcab7 | ||
|
|
bddb2646bd | ||
|
|
e4c57f54f8 | ||
|
|
6de82ca843 | ||
|
|
b2c02df555 | ||
|
|
ca86d6361e | ||
|
|
b6fb00e046 | ||
|
|
86c84972c8 | ||
|
|
9390927875 | ||
|
|
c4a585f232 | ||
|
|
300feb3245 | ||
|
|
cacafb0038 | ||
|
|
6509114259 | ||
|
|
7d4cb79822 | ||
|
|
b867e164fe | ||
|
|
26bbfc084d | ||
|
|
c376eed31d | ||
|
|
7c595abc38 | ||
|
|
c428ab68d8 | ||
|
|
968b9f1852 | ||
|
|
018266c66e | ||
|
|
111c644bf1 | ||
|
|
de72d1f0e7 | ||
|
|
8bfb856923 | ||
|
|
8fdbaab95d | ||
|
|
a01668bbe8 | ||
|
|
3385616a37 | ||
|
|
1f0d89328d | ||
|
|
a7feab45d5 | ||
|
|
f34322afd7 | ||
|
|
3815fa40b7 | ||
|
|
c43050b3fa | ||
|
|
3e152872ad | ||
|
|
ae6ad55758 | ||
|
|
0118a2fc04 | ||
|
|
4dd81976f4 | ||
|
|
2b4da8baf6 | ||
|
|
7d1b4071e8 | ||
|
|
8fc5377f50 | ||
|
|
e5812f261d | ||
|
|
f7e85cd7de | ||
|
|
749395420b | ||
|
|
7d536d1d75 | ||
|
|
7fd0d2fc2f | ||
|
|
ec696bbcdd | ||
|
|
df24345d65 | ||
|
|
386dd26097 | ||
|
|
514f976cc1 | ||
|
|
66b870fd08 | ||
|
|
24d3c7e378 | ||
|
|
484128b641 | ||
|
|
588ea95732 | ||
|
|
800567cde7 | ||
|
|
7a3ba5a25d | ||
|
|
dfff411e1a | ||
|
|
e20baa4218 | ||
|
|
d1ab9b501a | ||
|
|
3cbc9109ea | ||
|
|
3259397f89 | ||
|
|
eb5af3d90b | ||
|
|
b6810b209a | ||
|
|
158e0e1f63 | ||
|
|
294a103ead | ||
|
|
7f71276ad8 | ||
|
|
93d4570a59 | ||
|
|
527ba2eb2e | ||
|
|
3021b31cf3 | ||
|
|
9f2427907e | ||
|
|
570ce100c1 | ||
|
|
27547355e6 | ||
|
|
c5ef52a67a | ||
|
|
b48b47d519 | ||
|
|
9bdba2f6a8 | ||
|
|
d6ce902d80 | ||
|
|
ce6dcf3600 | ||
|
|
e7f92d16d8 | ||
|
|
abd26f5f67 | ||
|
|
4d35ace75e | ||
|
|
72222d1598 | ||
|
|
26d914b8fc | ||
|
|
7b01c0676c | ||
|
|
571a9b8669 | ||
|
|
ed35eb1e9e | ||
|
|
d291e0d60d | ||
|
|
1874d579c5 | ||
|
|
c692339020 | ||
|
|
2c1eef34cb | ||
|
|
af178cbcd1 | ||
|
|
5d85be31ca | ||
|
|
372b71c847 | ||
|
|
41a9c415e1 | ||
|
|
915e32a5f8 | ||
|
|
f4dd429cbf | ||
|
|
7435cde2ef | ||
|
|
7056087e92 | ||
|
|
fed7ae5661 | ||
|
|
5019c6148b | ||
|
|
2e1396cd6b | ||
|
|
b5e9df5df8 | ||
|
|
3622856994 | ||
|
|
7367c6ec21 | ||
|
|
6579ec8c4c | ||
|
|
a7fbae47d5 | ||
|
|
f203a9d78e | ||
|
|
bae73e676c | ||
|
|
806e1061d4 | ||
|
|
f920091667 | ||
|
|
801979f779 | ||
|
|
df2d32e7aa | ||
|
|
60cf12727b | ||
|
|
7621526d22 | ||
|
|
559b84dceb | ||
|
|
7e4c5d4bb3 | ||
|
|
2a4ed6610e | ||
|
|
1d8e9c7897 | ||
|
|
43654028eb | ||
|
|
2f6fc27c8b | ||
|
|
d789b667d7 | ||
|
|
66a1abac6a | ||
|
|
665db18661 | ||
|
|
30d97ca879 | ||
|
|
c62a6ca59d | ||
|
|
77c2c7076b | ||
|
|
7466fd4387 | ||
|
|
c1369a1ec9 | ||
|
|
d677fe053d | ||
|
|
7c6785d3df | ||
|
|
77341ee3c4 | ||
|
|
5b4b60cfb5 | ||
|
|
0f3d54d8a0 | ||
|
|
7272792f65 | ||
|
|
4cc8e16595 | ||
|
|
ca5a759f94 | ||
|
|
be51e56a2e | ||
|
|
3a9171e275 | ||
|
|
bd0f3b4050 | ||
|
|
206a8364d4 | ||
|
|
097d031066 | ||
|
|
2674b42b59 | ||
|
|
edf2e51bbc | ||
|
|
47877acc2a | ||
|
|
d111a324bc | ||
|
|
388f0a6e05 | ||
|
|
8c13c02c55 | ||
|
|
a101fde917 | ||
|
|
1f4373b6e5 | ||
|
|
525747b472 | ||
|
|
472f12c985 | ||
|
|
b681f24f43 | ||
|
|
fd02b089b6 | ||
|
|
57d4c4a4f8 | ||
|
|
3595d26846 | ||
|
|
22a79c169d | ||
|
|
75dfe259cf | ||
|
|
2e257d6af0 | ||
|
|
e734222373 | ||
|
|
6a351b9912 | ||
|
|
cfc04aa162 | ||
|
|
943c795318 | ||
|
|
7fb61bad04 | ||
|
|
47efcdb1dd | ||
|
|
59cbce1a46 | ||
|
|
7e755e9cac | ||
|
|
9d1e2c3c1f | ||
|
|
5af32ce705 | ||
|
|
4e8861e653 | ||
|
|
d4d7ffb17c | ||
|
|
46f834ec75 | ||
|
|
6ec64a7e56 | ||
|
|
d71446e387 | ||
|
|
eada49e56b | ||
|
|
8f42d7df56 | ||
|
|
33a90b9026 | ||
|
|
710902b0d0 | ||
|
|
7b4f5d3b21 | ||
|
|
13093963b1 | ||
|
|
2e477e7458 | ||
|
|
4b6252151e | ||
|
|
f3765d1996 | ||
|
|
1f5cdd66b7 | ||
|
|
5b0ddbb835 | ||
|
|
4f92b56f06 | ||
|
|
a1f6ff92be | ||
|
|
ef98e91618 | ||
|
|
9fdf800750 | ||
|
|
32c698e4c2 | ||
|
|
75e80fa820 | ||
|
|
f8329bc632 | ||
|
|
9f74d36ba4 | ||
|
|
fc2435f135 | ||
|
|
0636519ba3 | ||
|
|
573bf03a6f | ||
|
|
9e529be4e7 | ||
|
|
7af4ffa6cc | ||
|
|
5b67ccd1c6 | ||
|
|
5166dbbcd3 | ||
|
|
21adb09730 | ||
|
|
28b5f656db | ||
|
|
68ee2d512f | ||
|
|
a5f7e0efc6 | ||
|
|
211038584a | ||
|
|
ff5ba97970 | ||
|
|
27f2c3cae1 | ||
|
|
48f0819327 | ||
|
|
5c6d88e91c | ||
|
|
0a04d9470f | ||
|
|
f0408c0dde | ||
|
|
a041f4a111 | ||
|
|
cdf9dae53e | ||
|
|
1917f431f5 | ||
|
|
a770afbff2 | ||
|
|
b1a5bf025b | ||
|
|
adff3e5050 | ||
|
|
0e88c5754f | ||
|
|
3fff875f99 | ||
|
|
e2d9ab3591 | ||
|
|
3db5cf44ea | ||
|
|
994b9089e9 | ||
|
|
4c1513a845 | ||
|
|
86e009b504 | ||
|
|
c1e1918db1 | ||
|
|
341225a405 | ||
|
|
8c93921952 | ||
|
|
45367105fc | ||
|
|
df71359069 | ||
|
|
a03d14a9a6 | ||
|
|
41d7ca395e | ||
|
|
757573bec1 | ||
|
|
16d655b119 | ||
|
|
f6483de197 | ||
|
|
da34411bf2 | ||
|
|
1891b64072 | ||
|
|
a14069acf8 | ||
|
|
0ea708c226 | ||
|
|
cb474c7b11 | ||
|
|
e4d11a117b | ||
|
|
68365045b4 | ||
|
|
502555b65d | ||
|
|
0bc52c0aae | ||
|
|
6bf2663b8e | ||
|
|
d337de668e | ||
|
|
ec372f91e9 | ||
|
|
20b1bd8c54 | ||
|
|
ee17741591 | ||
|
|
93a6925ec5 | ||
|
|
47405a8e8a | ||
|
|
54ba30c47f | ||
|
|
b92214f78b | ||
|
|
71e4404c0d | ||
|
|
5ab997d484 | ||
|
|
6e7048831b | ||
|
|
97cd932c19 | ||
|
|
dfc7a7d5cd | ||
|
|
27e13a8371 | ||
|
|
bf6ad1fbed | ||
|
|
bc71380b59 | ||
|
|
137c87ff60 | ||
|
|
485b8dc18b | ||
|
|
875f9078d1 | ||
|
|
d3bfcbd3af | ||
|
|
e36db692e7 | ||
|
|
460a40756c | ||
|
|
18057e14ef | ||
|
|
025c8fe302 | ||
|
|
446129ca7a | ||
|
|
834c4e8ad9 | ||
|
|
11d961cf3c | ||
|
|
00b93d8b2f | ||
|
|
281fd5bb89 | ||
|
|
cb10050cb9 | ||
|
|
2935c4cddb | ||
|
|
0d6ec70c6f | ||
|
|
74777b4ded | ||
|
|
5f2bd04799 | ||
|
|
9a1a5f9778 | ||
|
|
edc8aefa59 | ||
|
|
ee1c786a12 | ||
|
|
a3e4f2b716 | ||
|
|
6685f1fb9e | ||
|
|
c89ff328f6 | ||
|
|
c6f1bc65c0 | ||
|
|
0f43c61229 | ||
|
|
8567dab167 | ||
|
|
0517d7bee5 | ||
|
|
5bc0b9b31c | ||
|
|
3d219b91b9 | ||
|
|
a90c6306f8 | ||
|
|
60558388ec | ||
|
|
b29a7f8cd6 | ||
|
|
a1501591e8 | ||
|
|
1408aa078d | ||
|
|
5acaa476d6 | ||
|
|
8ac4f87c91 | ||
|
|
14d3001824 | ||
|
|
1ac9389ddc | ||
|
|
0b0e27c2f1 | ||
|
|
fd1199cce4 | ||
|
|
3c9eda8265 | ||
|
|
6622cdb43f | ||
|
|
49c28a7dab | ||
|
|
a42671c2d7 | ||
|
|
f17ab6ad92 | ||
|
|
ca548af2a2 | ||
|
|
579997688f | ||
|
|
e6ba7ef3e6 | ||
|
|
20fdf177e8 | ||
|
|
f0b01803ea | ||
|
|
f5c4841ff2 | ||
|
|
1e01283d81 | ||
|
|
2196448c21 | ||
|
|
96a81ce89d | ||
|
|
a715490c2a | ||
|
|
973cf8e980 | ||
|
|
4357e42391 | ||
|
|
884b49e662 | ||
|
|
38c94d2e9c | ||
|
|
67d2eb6b2a | ||
|
|
b670fb57db | ||
|
|
188b4be64d | ||
|
|
889c042ecd | ||
|
|
3c4f8eaa55 | ||
|
|
6a75d57060 | ||
|
|
fda2cf677b | ||
|
|
cfdf5a5a78 | ||
|
|
a1437c15f7 | ||
|
|
42e7489713 | ||
|
|
024760f866 | ||
|
|
46f0189e88 | ||
|
|
edc7498111 | ||
|
|
9103fdf866 | ||
|
|
95bf795de4 | ||
|
|
bf99223a80 | ||
|
|
9caf9b6f91 | ||
|
|
727c7b0dc6 | ||
|
|
13d184b280 | ||
|
|
12a91774b0 | ||
|
|
88018000ac | ||
|
|
f6eda1c35d | ||
|
|
a2ebdbc112 | ||
|
|
e930a42083 | ||
|
|
4b123f49cb | ||
|
|
556eca918d | ||
|
|
31fcd03f3c | ||
|
|
89d9dd5aa5 | ||
|
|
d1aad72826 | ||
|
|
8e5b4bddf4 | ||
|
|
5a7cb9af4e | ||
|
|
d1cda4ec68 | ||
|
|
8aaf1185a5 | ||
|
|
b46bd07119 | ||
|
|
08fa707085 | ||
|
|
72ba29d81a | ||
|
|
cf2dc4c444 | ||
|
|
d82d86e16d | ||
|
|
bde31d8600 | ||
|
|
e115d55585 | ||
|
|
daea86e047 | ||
|
|
a4f69d8914 | ||
|
|
98f382fda3 | ||
|
|
cd899734f3 | ||
|
|
f51b435bcf | ||
|
|
0f82a55305 | ||
|
|
9fd7a410bb | ||
|
|
98fb3d015a | ||
|
|
bfb2ad7c79 | ||
|
|
135bfbf7c1 | ||
|
|
c6b17ebc20 | ||
|
|
b55eb30474 | ||
|
|
cec2f1fc00 | ||
|
|
8367ec03a7 | ||
|
|
37013f8068 | ||
|
|
8360544d65 | ||
|
|
b5cdef43a1 | ||
|
|
2e5d521ed8 | ||
|
|
dbe35d52d1 | ||
|
|
8bcdb6f52c | ||
|
|
5cfcb8262e | ||
|
|
0b331a318b | ||
|
|
5d6cf55208 | ||
|
|
9a1ec19845 | ||
|
|
a79e93f335 | ||
|
|
abcb94a738 | ||
|
|
a4f2d5aa6f | ||
|
|
6b738d1c89 | ||
|
|
f4c518b370 | ||
|
|
d475dd3809 | ||
|
|
5675c47a01 | ||
|
|
16e950454e | ||
|
|
2926265a14 | ||
|
|
af2607de1a | ||
|
|
826d7808b4 | ||
|
|
4c89aca243 | ||
|
|
43a065bb07 | ||
|
|
4513a2cc75 | ||
|
|
f29c1ac6e5 | ||
|
|
05abe47c8b | ||
|
|
6c185a2c57 | ||
|
|
af2cb33bb2 | ||
|
|
f16a4a8264 | ||
|
|
b232552d42 | ||
|
|
0edccc11a5 | ||
|
|
b2f5c0e0db | ||
|
|
5f5d4c1923 | ||
|
|
a7d7f79855 | ||
|
|
fa3150548e | ||
|
|
c7479751e8 | ||
|
|
870a54ac84 | ||
|
|
12fcfc2b72 | ||
|
|
95ae30f678 | ||
|
|
7408e778ca | ||
|
|
ba303fd1aa | ||
|
|
dd7a1dbfae | ||
|
|
f91fe10985 | ||
|
|
c7ab302c69 |
@@ -4,10 +4,12 @@
|
|||||||
.venv
|
.venv
|
||||||
cache
|
cache
|
||||||
data
|
data
|
||||||
|
docker
|
||||||
|
saves
|
||||||
hf_cache
|
hf_cache
|
||||||
|
ms_cache
|
||||||
|
om_cache
|
||||||
output
|
output
|
||||||
examples
|
|
||||||
.dockerignore
|
.dockerignore
|
||||||
.gitattributes
|
.gitattributes
|
||||||
.gitignore
|
.gitignore
|
||||||
Dockerfile
|
|
||||||
|
|||||||
37
.env.local
Normal file
37
.env.local
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Note: actually we do not support .env, just for reference
|
||||||
|
# api
|
||||||
|
API_HOST=
|
||||||
|
API_PORT=
|
||||||
|
API_KEY=
|
||||||
|
API_MODEL_NAME=
|
||||||
|
FASTAPI_ROOT_PATH=
|
||||||
|
MAX_CONCURRENT=
|
||||||
|
# general
|
||||||
|
DISABLE_VERSION_CHECK=
|
||||||
|
FORCE_CHECK_IMPORTS=
|
||||||
|
LLAMAFACTORY_VERBOSITY=
|
||||||
|
USE_MODELSCOPE_HUB=
|
||||||
|
USE_OPENMIND_HUB=
|
||||||
|
RECORD_VRAM=
|
||||||
|
# torchrun
|
||||||
|
FORCE_TORCHRUN=
|
||||||
|
MASTER_ADDR=
|
||||||
|
MASTER_PORT=
|
||||||
|
NNODES=
|
||||||
|
NODE_RANK=
|
||||||
|
NPROC_PER_NODE=
|
||||||
|
# wandb
|
||||||
|
WANDB_DISABLED=
|
||||||
|
WANDB_PROJECT=
|
||||||
|
WANDB_API_KEY=
|
||||||
|
# gradio ui
|
||||||
|
GRADIO_SHARE=
|
||||||
|
GRADIO_SERVER_NAME=
|
||||||
|
GRADIO_SERVER_PORT=
|
||||||
|
GRADIO_ROOT_PATH=
|
||||||
|
GRADIO_IPV6=
|
||||||
|
# setup
|
||||||
|
ENABLE_SHORT_CONSOLE=1
|
||||||
|
# reserved (do not use)
|
||||||
|
LLAMABOARD_ENABLED=
|
||||||
|
LLAMABOARD_WORKDIR=
|
||||||
46
.github/CONTRIBUTING.md
vendored
46
.github/CONTRIBUTING.md
vendored
@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
|
|||||||
### Style guide
|
### Style guide
|
||||||
|
|
||||||
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
|
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
|
||||||
|
|
||||||
|
### Create a Pull Request
|
||||||
|
|
||||||
|
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
|
||||||
|
|
||||||
|
2. Clone your fork to your local disk, and add the base repository as a remote:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone git@github.com:[username]/LLaMA-Factory.git
|
||||||
|
cd LLaMA-Factory
|
||||||
|
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Create a new branch to hold your development changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git checkout -b dev_your_branch
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Set up a development environment by running the following command in a virtual environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
|
||||||
|
|
||||||
|
5. Check code before commit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make commit
|
||||||
|
make style && make quality
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Submit changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add .
|
||||||
|
git commit -m "commit message"
|
||||||
|
git fetch upstream
|
||||||
|
git rebase upstream/main
|
||||||
|
git push -u origin dev_your_branch
|
||||||
|
```
|
||||||
|
|
||||||
|
7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
|
||||||
|
|||||||
10
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
10
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,13 +1,19 @@
|
|||||||
name: "\U0001F41B Bug / Help"
|
name: "\U0001F41B Bug / Help"
|
||||||
description: Create a report to help us improve the LLaMA Factory
|
description: Create a report to help us improve the LLaMA Factory
|
||||||
body:
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Issues included in **FAQs** or those with **insufficient** information may be closed without a response.
|
||||||
|
包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。
|
||||||
|
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
id: reminder
|
id: reminder
|
||||||
attributes:
|
attributes:
|
||||||
label: Reminder
|
label: Reminder
|
||||||
description: |
|
description: |
|
||||||
Please ensure you have read the README carefully and searched the existing issues.
|
Please ensure you have read the README carefully and searched the existing issues (including FAQs).
|
||||||
请确保您已经认真阅读了 README 并且搜索过现有的 Issue。
|
请确保您已经认真阅读了 README 并且搜索过现有的 issues(包括常见问题)。
|
||||||
|
|
||||||
options:
|
options:
|
||||||
- label: I have read the README and searched the existing issues.
|
- label: I have read the README and searched the existing issues.
|
||||||
|
|||||||
15
.github/workflows/label_issue.yml
vendored
15
.github/workflows/label_issue.yml
vendored
@@ -9,9 +9,22 @@ jobs:
|
|||||||
label_issue:
|
label_issue:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- env:
|
- env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
ISSUE_URL: ${{ github.event.issue.html_url }}
|
ISSUE_URL: ${{ github.event.issue.html_url }}
|
||||||
|
ISSUE_TITLE: ${{ github.event.issue.title }}
|
||||||
run: |
|
run: |
|
||||||
gh issue edit $ISSUE_URL --add-label "pending"
|
LABEL=pending
|
||||||
|
NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
|
||||||
|
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
|
||||||
|
for KEYWORD in ${NPU_KEYWORDS[@]}; do
|
||||||
|
if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
|
||||||
|
LABEL=pending,npu
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
gh issue edit $ISSUE_URL --add-label $LABEL
|
||||||
|
|||||||
6
.github/workflows/publish.yml
vendored
6
.github/workflows/publish.yml
vendored
@@ -26,15 +26,15 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.8"
|
python-version: "3.8"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install build
|
python -m pip install build
|
||||||
|
|
||||||
- name: Build package
|
- name: Build package
|
||||||
run: |
|
run: |
|
||||||
python -m build
|
python -m build
|
||||||
|
|
||||||
- name: Publish package
|
- name: Publish package
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
|||||||
30
.github/workflows/tests.yml
vendored
30
.github/workflows/tests.yml
vendored
@@ -3,14 +3,14 @@ name: tests
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- "**.py"
|
- "**.py"
|
||||||
- "requirements.txt"
|
- "requirements.txt"
|
||||||
- ".github/workflows/*.yml"
|
- ".github/workflows/*.yml"
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- "**.py"
|
- "**.py"
|
||||||
- "requirements.txt"
|
- "requirements.txt"
|
||||||
@@ -18,7 +18,27 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
tests:
|
tests:
|
||||||
runs-on: ubuntu-latest
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version:
|
||||||
|
- "3.8" # TODO: remove py38 in next transformers release
|
||||||
|
- "3.9"
|
||||||
|
- "3.10"
|
||||||
|
- "3.11"
|
||||||
|
os:
|
||||||
|
- "ubuntu-latest"
|
||||||
|
- "windows-latest"
|
||||||
|
- "macos-13"
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
environment:
|
||||||
|
name: tests
|
||||||
|
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
OS_NAME: ${{ matrix.os }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -27,14 +47,14 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.8"
|
python-version: ${{ matrix.python-version }}
|
||||||
cache: "pip"
|
cache: "pip"
|
||||||
cache-dependency-path: "setup.py"
|
cache-dependency-path: "setup.py"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install .[torch,dev]
|
python -m pip install ".[torch,dev]"
|
||||||
|
|
||||||
- name: Check quality
|
- name: Check quality
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -159,7 +159,15 @@ cython_debug/
|
|||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
|
# vscode
|
||||||
|
.vscode/
|
||||||
|
|
||||||
# custom .gitignore
|
# custom .gitignore
|
||||||
user.config
|
ms_cache/
|
||||||
saves/
|
hf_cache/
|
||||||
|
om_cache/
|
||||||
cache/
|
cache/
|
||||||
|
config/
|
||||||
|
saves/
|
||||||
|
output/
|
||||||
|
wandb/
|
||||||
|
|||||||
28
.pre-commit-config.yaml
Normal file
28
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-ast
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ['--maxkb=25000']
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: check-yaml
|
||||||
|
- id: debug-statements
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
args: [--markdown-linebreak-ext=md]
|
||||||
|
- id: no-commit-to-branch
|
||||||
|
args: ['--branch', 'main']
|
||||||
|
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v3.17.0
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
args: [--py38-plus]
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.6.9
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
- id: ruff-format
|
||||||
11
CITATION.cff
11
CITATION.cff
@@ -12,12 +12,16 @@ authors:
|
|||||||
given-names: "Yanhan"
|
given-names: "Yanhan"
|
||||||
- family-names: "Luo"
|
- family-names: "Luo"
|
||||||
given-names: "Zheyan"
|
given-names: "Zheyan"
|
||||||
|
- family-names: "Feng"
|
||||||
|
given-names: "Zhangchi"
|
||||||
- family-names: "Ma"
|
- family-names: "Ma"
|
||||||
given-names: "Yongqiang"
|
given-names: "Yongqiang"
|
||||||
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
|
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
|
||||||
url: "https://arxiv.org/abs/2403.13372"
|
url: "https://arxiv.org/abs/2403.13372"
|
||||||
preferred-citation:
|
preferred-citation:
|
||||||
type: article
|
type: conference-paper
|
||||||
|
conference:
|
||||||
|
name: "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)"
|
||||||
authors:
|
authors:
|
||||||
- family-names: "Zheng"
|
- family-names: "Zheng"
|
||||||
given-names: "Yaowei"
|
given-names: "Yaowei"
|
||||||
@@ -29,9 +33,12 @@ preferred-citation:
|
|||||||
given-names: "Yanhan"
|
given-names: "Yanhan"
|
||||||
- family-names: "Luo"
|
- family-names: "Luo"
|
||||||
given-names: "Zheyan"
|
given-names: "Zheyan"
|
||||||
|
- family-names: "Feng"
|
||||||
|
given-names: "Zhangchi"
|
||||||
- family-names: "Ma"
|
- family-names: "Ma"
|
||||||
given-names: "Yongqiang"
|
given-names: "Yongqiang"
|
||||||
journal: "arXiv preprint arXiv:2403.13372"
|
|
||||||
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
|
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
|
||||||
url: "https://arxiv.org/abs/2403.13372"
|
url: "https://arxiv.org/abs/2403.13372"
|
||||||
year: 2024
|
year: 2024
|
||||||
|
publisher: "Association for Computational Linguistics"
|
||||||
|
address: "Bangkok, Thailand"
|
||||||
|
|||||||
47
Dockerfile
47
Dockerfile
@@ -1,47 +0,0 @@
|
|||||||
# Use the NVIDIA official image with PyTorch 2.3.0
|
|
||||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
|
|
||||||
FROM nvcr.io/nvidia/pytorch:24.02-py3
|
|
||||||
|
|
||||||
# Define installation arguments
|
|
||||||
ARG INSTALL_BNB=false
|
|
||||||
ARG INSTALL_VLLM=false
|
|
||||||
ARG INSTALL_DEEPSPEED=false
|
|
||||||
ARG PIP_INDEX=https://pypi.org/simple
|
|
||||||
|
|
||||||
# Set the working directory
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install the requirements
|
|
||||||
COPY requirements.txt /app/
|
|
||||||
RUN pip config set global.index-url $PIP_INDEX
|
|
||||||
RUN python -m pip install --upgrade pip
|
|
||||||
RUN python -m pip install -r requirements.txt
|
|
||||||
|
|
||||||
# Copy the rest of the application into the image
|
|
||||||
COPY . /app/
|
|
||||||
|
|
||||||
# Install the LLaMA Factory
|
|
||||||
RUN EXTRA_PACKAGES="metrics"; \
|
|
||||||
if [ "$INSTALL_BNB" = "true" ]; then \
|
|
||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
|
|
||||||
fi; \
|
|
||||||
if [ "$INSTALL_VLLM" = "true" ]; then \
|
|
||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
|
|
||||||
fi; \
|
|
||||||
if [ "$INSTALL_DEEPSPEED" = "true" ]; then \
|
|
||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
|
||||||
fi; \
|
|
||||||
pip install -e .[$EXTRA_PACKAGES] && \
|
|
||||||
pip uninstall -y transformer-engine flash-attn
|
|
||||||
|
|
||||||
# Set up volumes
|
|
||||||
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
|
||||||
|
|
||||||
# Expose port 7860 for the LLaMA Board
|
|
||||||
EXPOSE 7860
|
|
||||||
|
|
||||||
# Expose port 8000 for the API service
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
# Launch LLaMA Board
|
|
||||||
CMD [ "llamafactory-cli", "webui" ]
|
|
||||||
13
Makefile
13
Makefile
@@ -1,6 +1,13 @@
|
|||||||
.PHONY: quality style test
|
.PHONY: build commit quality style test
|
||||||
|
|
||||||
check_dirs := scripts src tests
|
check_dirs := scripts src tests setup.py
|
||||||
|
|
||||||
|
build:
|
||||||
|
pip install build && python -m build
|
||||||
|
|
||||||
|
commit:
|
||||||
|
pre-commit install
|
||||||
|
pre-commit run --all-files
|
||||||
|
|
||||||
quality:
|
quality:
|
||||||
ruff check $(check_dirs)
|
ruff check $(check_dirs)
|
||||||
@@ -11,4 +18,4 @@ style:
|
|||||||
ruff format $(check_dirs)
|
ruff format $(check_dirs)
|
||||||
|
|
||||||
test:
|
test:
|
||||||
CUDA_VISIBLE_DEVICES= pytest tests/
|
CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
|
||||||
|
|||||||
353
README.md
353
README.md
@@ -4,7 +4,7 @@
|
|||||||
[](LICENSE)
|
[](LICENSE)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
[](https://pypi.org/project/llamafactory/)
|
[](https://pypi.org/project/llamafactory/)
|
||||||
[](#projects-using-llama-factory)
|
[](#projects-using-llama-factory)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
[](https://twitter.com/llamafactory_ai)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
@@ -12,22 +12,32 @@
|
|||||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
[](https://trendshift.io/repositories/4535)
|
[](https://trendshift.io/repositories/4535)
|
||||||
|
|
||||||
👋 Join our [WeChat](assets/wechat.jpg).
|
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
|
||||||
|
|
||||||
\[ English | [中文](README_zh.md) \]
|
\[ English | [中文](README_zh.md) \]
|
||||||
|
|
||||||
**Fine-tuning a large language model can be easy as...**
|
**Fine-tuning a large language model can be easy as...**
|
||||||
|
|
||||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
|
https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
|
||||||
|
|
||||||
Choose your path:
|
Choose your path:
|
||||||
|
|
||||||
|
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||||
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
|
||||||
- **Local machine**: Please refer to [usage](#getting-started)
|
- **Local machine**: Please refer to [usage](#getting-started)
|
||||||
|
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
|
||||||
|
- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
|
Recent activities:
|
||||||
|
|
||||||
|
- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
@@ -46,11 +56,11 @@ Choose your path:
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
||||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
|
||||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
|
|
||||||
@@ -71,15 +81,27 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
|
[24/10/09] We support downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
|
||||||
|
|
||||||
|
[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
|
||||||
|
|
||||||
|
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
||||||
|
|
||||||
|
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
|
||||||
|
|
||||||
|
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
|
||||||
|
|
||||||
|
<details><summary>Full Changelog</summary>
|
||||||
|
|
||||||
|
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
||||||
|
|
||||||
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
||||||
|
|
||||||
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
<details><summary>Full Changelog</summary>
|
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion.
|
||||||
|
|
||||||
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
|
||||||
|
|
||||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
@@ -91,7 +113,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
|
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
||||||
|
|
||||||
@@ -103,7 +125,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
|
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
||||||
|
|
||||||
@@ -119,7 +141,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
||||||
|
|
||||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
|
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
|
||||||
|
|
||||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
||||||
|
|
||||||
@@ -151,35 +173,40 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
| Model | Model size | Template |
|
| Model | Model size | Template |
|
||||||
| --------------------------------------------------------- | -------------------------------- | --------- |
|
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
|
||||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
||||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||||
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||||
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||||
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||||
@@ -203,6 +230,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||||||
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
|
||||||
|
|
||||||
## Provided Datasets
|
## Provided Datasets
|
||||||
|
|
||||||
<details><summary>Pre-training datasets</summary>
|
<details><summary>Pre-training datasets</summary>
|
||||||
@@ -262,7 +292,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||||||
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||||
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
||||||
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
||||||
|
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
||||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||||
|
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
||||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||||
@@ -279,6 +311,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||||||
|
|
||||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||||
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
||||||
|
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
|
||||||
|
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
|
||||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
@@ -299,20 +333,20 @@ huggingface-cli login
|
|||||||
| Mandatory | Minimum | Recommend |
|
| Mandatory | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.8 | 3.11 |
|
| python | 3.8 | 3.11 |
|
||||||
| torch | 1.13.1 | 2.3.0 |
|
| torch | 1.13.1 | 2.4.0 |
|
||||||
| transformers | 4.41.2 | 4.41.2 |
|
| transformers | 4.41.2 | 4.43.4 |
|
||||||
| datasets | 2.16.0 | 2.19.2 |
|
| datasets | 2.16.0 | 2.20.0 |
|
||||||
| accelerate | 0.30.1 | 0.30.1 |
|
| accelerate | 0.30.1 | 0.32.0 |
|
||||||
| peft | 0.11.1 | 0.11.1 |
|
| peft | 0.11.1 | 0.12.0 |
|
||||||
| trl | 0.8.6 | 0.9.4 |
|
| trl | 0.8.6 | 0.9.6 |
|
||||||
|
|
||||||
| Optional | Minimum | Recommend |
|
| Optional | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| CUDA | 11.6 | 12.2 |
|
| CUDA | 11.6 | 12.2 |
|
||||||
| deepspeed | 0.10.0 | 0.14.0 |
|
| deepspeed | 0.10.0 | 0.14.0 |
|
||||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||||
| vllm | 0.4.3 | 0.4.3 |
|
| vllm | 0.4.3 | 0.5.0 |
|
||||||
| flash-attn | 2.3.0 | 2.5.9 |
|
| flash-attn | 2.3.0 | 2.6.3 |
|
||||||
|
|
||||||
### Hardware Requirement
|
### Hardware Requirement
|
||||||
|
|
||||||
@@ -341,7 +375,7 @@ cd LLaMA-Factory
|
|||||||
pip install -e ".[torch,metrics]"
|
pip install -e ".[torch,metrics]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
|
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
||||||
@@ -360,9 +394,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
|||||||
|
|
||||||
<details><summary>For Ascend NPU users</summary>
|
<details><summary>For Ascend NPU users</summary>
|
||||||
|
|
||||||
Join [NPU user group](assets/wechat_npu.jpg).
|
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||||
|
|
||||||
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# replace the url according to your CANN version and devices
|
# replace the url according to your CANN version and devices
|
||||||
@@ -385,20 +417,17 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|||||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
| deepspeed | 0.13.2 | 0.13.2 |
|
||||||
|
|
||||||
Docker image:
|
|
||||||
|
|
||||||
- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
|
||||||
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
|
||||||
|
|
||||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
||||||
|
|
||||||
If you cannot infer the model on NPU devices, try setting `do_sample: false` in the configurations.
|
If you cannot infer the model on NPU devices, try setting `do_sample: false` in the configurations.
|
||||||
|
|
||||||
|
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### Data Preparation
|
### Data Preparation
|
||||||
|
|
||||||
Please refer to [data/README.md](data/README.md) for details about the format of dataset files. You can either use datasets on the HuggingFace / ModelScope hub or load the dataset from local disk.
|
Please refer to [data/README.md](data/README.md) for details about the format of dataset files. You can either use datasets on the HuggingFace / ModelScope / Modelers hub or load the dataset from local disk.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Please update `data/dataset_info.json` to use your custom dataset.
|
> Please update `data/dataset_info.json` to use your custom dataset.
|
||||||
@@ -426,18 +455,47 @@ llamafactory-cli webui
|
|||||||
|
|
||||||
### Build Docker
|
### Build Docker
|
||||||
|
|
||||||
#### Use Docker
|
For CUDA users:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker build -f ./Dockerfile \
|
cd docker/docker-cuda/
|
||||||
|
docker compose up -d
|
||||||
|
docker compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
For Ascend NPU users:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd docker/docker-npu/
|
||||||
|
docker compose up -d
|
||||||
|
docker compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
For AMD ROCm users:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd docker/docker-rocm/
|
||||||
|
docker compose up -d
|
||||||
|
docker compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
<details><summary>Build without Docker Compose</summary>
|
||||||
|
|
||||||
|
For CUDA users:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f ./docker/docker-cuda/Dockerfile \
|
||||||
--build-arg INSTALL_BNB=false \
|
--build-arg INSTALL_BNB=false \
|
||||||
--build-arg INSTALL_VLLM=false \
|
--build-arg INSTALL_VLLM=false \
|
||||||
--build-arg INSTALL_DEEPSPEED=false \
|
--build-arg INSTALL_DEEPSPEED=false \
|
||||||
|
--build-arg INSTALL_FLASHATTN=false \
|
||||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||||
-t llamafactory:latest .
|
-t llamafactory:latest .
|
||||||
|
|
||||||
docker run -it --gpus=all \
|
docker run -dit --gpus=all \
|
||||||
-v ./hf_cache:/root/.cache/huggingface/ \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-p 7860:7860 \
|
-p 7860:7860 \
|
||||||
@@ -445,20 +503,81 @@ docker run -it --gpus=all \
|
|||||||
--shm-size 16G \
|
--shm-size 16G \
|
||||||
--name llamafactory \
|
--name llamafactory \
|
||||||
llamafactory:latest
|
llamafactory:latest
|
||||||
|
|
||||||
|
docker exec -it llamafactory bash
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Use Docker Compose
|
For Ascend NPU users:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d
|
# Choose docker image upon your environment
|
||||||
docker-compose exec llamafactory bash
|
docker build -f ./docker/docker-npu/Dockerfile \
|
||||||
|
--build-arg INSTALL_DEEPSPEED=false \
|
||||||
|
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||||
|
-t llamafactory:latest .
|
||||||
|
|
||||||
|
# Change `device` upon your resources
|
||||||
|
docker run -dit \
|
||||||
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
|
-v ./data:/app/data \
|
||||||
|
-v ./output:/app/output \
|
||||||
|
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||||
|
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
||||||
|
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
||||||
|
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
||||||
|
-p 7860:7860 \
|
||||||
|
-p 8000:8000 \
|
||||||
|
--device /dev/davinci0 \
|
||||||
|
--device /dev/davinci_manager \
|
||||||
|
--device /dev/devmm_svm \
|
||||||
|
--device /dev/hisi_hdc \
|
||||||
|
--shm-size 16G \
|
||||||
|
--name llamafactory \
|
||||||
|
llamafactory:latest
|
||||||
|
|
||||||
|
docker exec -it llamafactory bash
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For AMD ROCm users:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f ./docker/docker-rocm/Dockerfile \
|
||||||
|
--build-arg INSTALL_BNB=false \
|
||||||
|
--build-arg INSTALL_VLLM=false \
|
||||||
|
--build-arg INSTALL_DEEPSPEED=false \
|
||||||
|
--build-arg INSTALL_FLASHATTN=false \
|
||||||
|
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||||
|
-t llamafactory:latest .
|
||||||
|
|
||||||
|
docker run -dit \
|
||||||
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
|
-v ./data:/app/data \
|
||||||
|
-v ./output:/app/output \
|
||||||
|
-v ./saves:/app/saves \
|
||||||
|
-p 7860:7860 \
|
||||||
|
-p 8000:8000 \
|
||||||
|
--device /dev/kfd \
|
||||||
|
--device /dev/dri \
|
||||||
|
--shm-size 16G \
|
||||||
|
--name llamafactory \
|
||||||
|
llamafactory:latest
|
||||||
|
|
||||||
|
docker exec -it llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details><summary>Details about volume</summary>
|
<details><summary>Details about volume</summary>
|
||||||
|
|
||||||
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
||||||
- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
|
||||||
- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
- `om_cache`: Similar to Hugging Face cache but for Modelers users.
|
||||||
|
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
||||||
|
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -469,7 +588,9 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Visit https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
|
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
|
||||||
|
>
|
||||||
|
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
|
||||||
|
|
||||||
### Download from ModelScope Hub
|
### Download from ModelScope Hub
|
||||||
|
|
||||||
@@ -481,6 +602,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
|||||||
|
|
||||||
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||||
|
|
||||||
|
### Download from Modelers Hub
|
||||||
|
|
||||||
|
You can also use Modelers Hub to download models and datasets.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
|
||||||
|
|
||||||
### Use W&B Logger
|
### Use W&B Logger
|
||||||
|
|
||||||
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
||||||
@@ -503,45 +634,95 @@ If you have a project that should be incorporated, please contact via email or c
|
|||||||
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
||||||
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
||||||
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
||||||
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
||||||
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
||||||
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||||
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||||
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||||
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||||
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||||
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||||
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
||||||
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
||||||
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
||||||
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
||||||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||||
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
||||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||||
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
||||||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||||
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
||||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||||
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
||||||
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
||||||
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
||||||
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
||||||
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
||||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||||
|
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
|
||||||
|
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
|
||||||
|
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
|
||||||
|
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
|
||||||
|
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
|
||||||
|
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
|
||||||
|
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
|
||||||
|
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
|
||||||
|
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
|
||||||
|
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
|
||||||
|
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
|
||||||
|
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
|
||||||
|
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
|
||||||
|
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
|
||||||
|
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
|
||||||
|
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
|
||||||
|
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
|
||||||
|
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
|
||||||
|
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
|
||||||
|
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
|
||||||
|
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
|
||||||
|
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
|
||||||
|
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
|
||||||
|
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
|
||||||
|
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
|
||||||
|
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
|
||||||
|
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
|
||||||
|
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
|
||||||
|
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
|
||||||
|
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
|
||||||
|
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
|
||||||
|
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
|
||||||
|
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
|
||||||
|
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
|
||||||
|
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
|
||||||
|
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
|
||||||
|
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
|
||||||
|
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
|
||||||
|
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
|
||||||
|
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
|
||||||
|
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
|
||||||
|
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
||||||
|
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
||||||
|
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
||||||
|
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
|
||||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
||||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
||||||
|
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
||||||
|
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
|
||||||
|
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
|
||||||
|
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
||||||
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -549,17 +730,19 @@ If you have a project that should be incorporated, please contact via email or c
|
|||||||
|
|
||||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||||
|
|
||||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
If this work is helpful, please kindly cite as:
|
If this work is helpful, please kindly cite as:
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@article{zheng2024llamafactory,
|
@inproceedings{zheng2024llamafactory,
|
||||||
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
||||||
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
|
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
|
||||||
journal={arXiv preprint arXiv:2403.13372},
|
booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
|
||||||
|
address={Bangkok, Thailand},
|
||||||
|
publisher={Association for Computational Linguistics},
|
||||||
year={2024},
|
year={2024},
|
||||||
url={http://arxiv.org/abs/2403.13372}
|
url={http://arxiv.org/abs/2403.13372}
|
||||||
}
|
}
|
||||||
|
|||||||
352
README_zh.md
352
README_zh.md
@@ -4,7 +4,7 @@
|
|||||||
[](LICENSE)
|
[](LICENSE)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
[](https://pypi.org/project/llamafactory/)
|
[](https://pypi.org/project/llamafactory/)
|
||||||
[](#使用了-llama-factory-的项目)
|
[](#使用了-llama-factory-的项目)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
[](https://twitter.com/llamafactory_ai)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
@@ -12,22 +12,33 @@
|
|||||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
[](https://trendshift.io/repositories/4535)
|
[](https://trendshift.io/repositories/4535)
|
||||||
|
|
||||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。
|
||||||
|
|
||||||
\[ [English](README.md) | 中文 \]
|
\[ [English](README.md) | 中文 \]
|
||||||
|
|
||||||
**微调大模型可以像这样轻松…**
|
**微调大模型可以像这样轻松…**
|
||||||
|
|
||||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd-d76c6d0a6594
|
https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||||
|
|
||||||
选择你的打开方式:
|
选择你的打开方式:
|
||||||
|
|
||||||
|
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||||
|
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||||
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
|
||||||
- **本地机器**:请见[如何使用](#如何使用)
|
- **本地机器**:请见[如何使用](#如何使用)
|
||||||
|
- **PAI-DSW**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
|
||||||
|
- **Amazon SageMaker**:[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
|
||||||
|
|
||||||
|
近期活动:
|
||||||
|
|
||||||
|
- **2024/10/18-2024/11/30**:使用 PAI+LLaMA Factory 构建个性化导游机器人。[[活动页面]](https://developer.aliyun.com/topic/llamafactory2)
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
@@ -46,11 +57,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
## 项目特色
|
## 项目特色
|
||||||
|
|
||||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
||||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
||||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
|
|
||||||
@@ -71,15 +82,27 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
|
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
|
||||||
|
|
||||||
|
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
|
||||||
|
|
||||||
|
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
|
||||||
|
|
||||||
|
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
|
||||||
|
|
||||||
|
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
|
||||||
|
|
||||||
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
|
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
|
||||||
|
|
||||||
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
||||||
|
|
||||||
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `paligemma` 模板进行微调使其获得对话能力。
|
||||||
|
|
||||||
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
|
|
||||||
|
|
||||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
@@ -91,7 +114,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||||
|
|
||||||
@@ -103,7 +126,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
|
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
|
||||||
|
|
||||||
@@ -151,35 +174,39 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
## 模型
|
## 模型
|
||||||
|
|
||||||
| 模型名 | 模型大小 | Template |
|
| 模型名 | 模型大小 | Template |
|
||||||
| --------------------------------------------------------- | -------------------------------- | --------- |
|
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
|
||||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
||||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||||
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||||
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||||
@@ -203,6 +230,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> 有关 PPO 的实现细节,请参考[此博客](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html)。
|
||||||
|
|
||||||
## 数据集
|
## 数据集
|
||||||
|
|
||||||
<details><summary>预训练数据集</summary>
|
<details><summary>预训练数据集</summary>
|
||||||
@@ -262,7 +292,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||||
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
||||||
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
||||||
|
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
||||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||||
|
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
||||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||||
@@ -279,6 +311,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||||
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
||||||
|
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
|
||||||
|
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
|
||||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
@@ -299,20 +333,20 @@ huggingface-cli login
|
|||||||
| 必需项 | 至少 | 推荐 |
|
| 必需项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.8 | 3.11 |
|
| python | 3.8 | 3.11 |
|
||||||
| torch | 1.13.1 | 2.3.0 |
|
| torch | 1.13.1 | 2.4.0 |
|
||||||
| transformers | 4.41.2 | 4.41.2 |
|
| transformers | 4.41.2 | 4.43.4 |
|
||||||
| datasets | 2.16.0 | 2.19.2 |
|
| datasets | 2.16.0 | 2.20.0 |
|
||||||
| accelerate | 0.30.1 | 0.30.1 |
|
| accelerate | 0.30.1 | 0.32.0 |
|
||||||
| peft | 0.11.1 | 0.11.1 |
|
| peft | 0.11.1 | 0.12.0 |
|
||||||
| trl | 0.8.6 | 0.9.4 |
|
| trl | 0.8.6 | 0.9.6 |
|
||||||
|
|
||||||
| 可选项 | 至少 | 推荐 |
|
| 可选项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| CUDA | 11.6 | 12.2 |
|
| CUDA | 11.6 | 12.2 |
|
||||||
| deepspeed | 0.10.0 | 0.14.0 |
|
| deepspeed | 0.10.0 | 0.14.0 |
|
||||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||||
| vllm | 0.4.3 | 0.4.3 |
|
| vllm | 0.4.3 | 0.5.0 |
|
||||||
| flash-attn | 2.3.0 | 2.5.9 |
|
| flash-attn | 2.3.0 | 2.6.3 |
|
||||||
|
|
||||||
### 硬件依赖
|
### 硬件依赖
|
||||||
|
|
||||||
@@ -341,7 +375,7 @@ cd LLaMA-Factory
|
|||||||
pip install -e ".[torch,metrics]"
|
pip install -e ".[torch,metrics]"
|
||||||
```
|
```
|
||||||
|
|
||||||
可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
|
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
||||||
@@ -360,9 +394,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||||||
|
|
||||||
<details><summary>昇腾 NPU 用户指南</summary>
|
<details><summary>昇腾 NPU 用户指南</summary>
|
||||||
|
|
||||||
加入 [NPU 用户群](assets/wechat_npu.jpg)。
|
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||||
|
|
||||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||||
@@ -385,20 +417,17 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|||||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
| deepspeed | 0.13.2 | 0.13.2 |
|
||||||
|
|
||||||
Docker 镜像:
|
|
||||||
|
|
||||||
- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
|
||||||
- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
|
||||||
|
|
||||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
||||||
|
|
||||||
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
||||||
|
|
||||||
|
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### 数据准备
|
### 数据准备
|
||||||
|
|
||||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
|
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
|
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
|
||||||
@@ -426,18 +455,47 @@ llamafactory-cli webui
|
|||||||
|
|
||||||
### 构建 Docker
|
### 构建 Docker
|
||||||
|
|
||||||
#### 使用 Docker
|
CUDA 用户:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker build -f ./Dockerfile \
|
cd docker/docker-cuda/
|
||||||
|
docker compose up -d
|
||||||
|
docker compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
昇腾 NPU 用户:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd docker/docker-npu/
|
||||||
|
docker compose up -d
|
||||||
|
docker compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
AMD ROCm 用户:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd docker/docker-rocm/
|
||||||
|
docker compose up -d
|
||||||
|
docker compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
<details><summary>不使用 Docker Compose 构建</summary>
|
||||||
|
|
||||||
|
CUDA 用户:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f ./docker/docker-cuda/Dockerfile \
|
||||||
--build-arg INSTALL_BNB=false \
|
--build-arg INSTALL_BNB=false \
|
||||||
--build-arg INSTALL_VLLM=false \
|
--build-arg INSTALL_VLLM=false \
|
||||||
--build-arg INSTALL_DEEPSPEED=false \
|
--build-arg INSTALL_DEEPSPEED=false \
|
||||||
|
--build-arg INSTALL_FLASHATTN=false \
|
||||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||||
-t llamafactory:latest .
|
-t llamafactory:latest .
|
||||||
|
|
||||||
docker run -it --gpus=all \
|
docker run -dit --gpus=all \
|
||||||
-v ./hf_cache:/root/.cache/huggingface/ \
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
-v ./output:/app/output \
|
-v ./output:/app/output \
|
||||||
-p 7860:7860 \
|
-p 7860:7860 \
|
||||||
@@ -445,20 +503,81 @@ docker run -it --gpus=all \
|
|||||||
--shm-size 16G \
|
--shm-size 16G \
|
||||||
--name llamafactory \
|
--name llamafactory \
|
||||||
llamafactory:latest
|
llamafactory:latest
|
||||||
|
|
||||||
|
docker exec -it llamafactory bash
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 Docker Compose
|
昇腾 NPU 用户:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d
|
# 根据您的环境选择镜像
|
||||||
docker-compose exec llamafactory bash
|
docker build -f ./docker/docker-npu/Dockerfile \
|
||||||
|
--build-arg INSTALL_DEEPSPEED=false \
|
||||||
|
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||||
|
-t llamafactory:latest .
|
||||||
|
|
||||||
|
# 根据您的资源更改 `device`
|
||||||
|
docker run -dit \
|
||||||
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
|
-v ./data:/app/data \
|
||||||
|
-v ./output:/app/output \
|
||||||
|
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||||
|
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
||||||
|
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
||||||
|
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
||||||
|
-p 7860:7860 \
|
||||||
|
-p 8000:8000 \
|
||||||
|
--device /dev/davinci0 \
|
||||||
|
--device /dev/davinci_manager \
|
||||||
|
--device /dev/devmm_svm \
|
||||||
|
--device /dev/hisi_hdc \
|
||||||
|
--shm-size 16G \
|
||||||
|
--name llamafactory \
|
||||||
|
llamafactory:latest
|
||||||
|
|
||||||
|
docker exec -it llamafactory bash
|
||||||
```
|
```
|
||||||
|
|
||||||
|
AMD ROCm 用户:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f ./docker/docker-rocm/Dockerfile \
|
||||||
|
--build-arg INSTALL_BNB=false \
|
||||||
|
--build-arg INSTALL_VLLM=false \
|
||||||
|
--build-arg INSTALL_DEEPSPEED=false \
|
||||||
|
--build-arg INSTALL_FLASHATTN=false \
|
||||||
|
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||||
|
-t llamafactory:latest .
|
||||||
|
|
||||||
|
docker run -dit \
|
||||||
|
-v ./hf_cache:/root/.cache/huggingface \
|
||||||
|
-v ./ms_cache:/root/.cache/modelscope \
|
||||||
|
-v ./om_cache:/root/.cache/openmind \
|
||||||
|
-v ./data:/app/data \
|
||||||
|
-v ./output:/app/output \
|
||||||
|
-v ./saves:/app/saves \
|
||||||
|
-p 7860:7860 \
|
||||||
|
-p 8000:8000 \
|
||||||
|
--device /dev/kfd \
|
||||||
|
--device /dev/dri \
|
||||||
|
--shm-size 16G \
|
||||||
|
--name llamafactory \
|
||||||
|
llamafactory:latest
|
||||||
|
|
||||||
|
docker exec -it llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details><summary>数据卷详情</summary>
|
<details><summary>数据卷详情</summary>
|
||||||
|
|
||||||
- hf_cache:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
|
- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
|
||||||
- data:宿主机中存放数据集的文件夹路径。
|
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
|
||||||
- output:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
|
- `om_cache`:类似 Hugging Face 缓存文件夹,为 Modelers 用户提供。
|
||||||
|
- `data`:宿主机中存放数据集的文件夹路径。
|
||||||
|
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -469,7 +588,9 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
|||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
|
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
|
||||||
|
>
|
||||||
|
> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)
|
||||||
|
|
||||||
### 从魔搭社区下载
|
### 从魔搭社区下载
|
||||||
|
|
||||||
@@ -481,6 +602,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
|||||||
|
|
||||||
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||||
|
|
||||||
|
### 从魔乐社区下载
|
||||||
|
|
||||||
|
您也可以通过下述方法,使用魔乐社区下载数据集和模型。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export USE_OPENMIND_HUB=1 # Windows 使用 `set USE_OPENMIND_HUB=1`
|
||||||
|
```
|
||||||
|
|
||||||
|
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔乐社区](https://modelers.cn/models)查看所有可用的模型,例如 `TeleAI/TeleChat-7B-pt`。
|
||||||
|
|
||||||
### 使用 W&B 面板
|
### 使用 W&B 面板
|
||||||
|
|
||||||
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
||||||
@@ -503,45 +634,94 @@ run_name: test_run # 可选
|
|||||||
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
||||||
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
||||||
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
||||||
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
||||||
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
||||||
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||||
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||||
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||||
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||||
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||||
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||||
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
||||||
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
||||||
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
||||||
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
||||||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||||
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
||||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||||
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
||||||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||||
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
||||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||||
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
||||||
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
||||||
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
||||||
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
||||||
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
||||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||||
|
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
|
||||||
|
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
|
||||||
|
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
|
||||||
|
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
|
||||||
|
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
|
||||||
|
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
|
||||||
|
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
|
||||||
|
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
|
||||||
|
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
|
||||||
|
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
|
||||||
|
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
|
||||||
|
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
|
||||||
|
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
|
||||||
|
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
|
||||||
|
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
|
||||||
|
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
|
||||||
|
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
|
||||||
|
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
|
||||||
|
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
|
||||||
|
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
|
||||||
|
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
|
||||||
|
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
|
||||||
|
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
|
||||||
|
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
|
||||||
|
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
|
||||||
|
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
|
||||||
|
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
|
||||||
|
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
|
||||||
|
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
|
||||||
|
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
|
||||||
|
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
|
||||||
|
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
|
||||||
|
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
|
||||||
|
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
|
||||||
|
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
|
||||||
|
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
|
||||||
|
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
|
||||||
|
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
|
||||||
|
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
|
||||||
|
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
|
||||||
|
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
|
||||||
|
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
||||||
|
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
||||||
|
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
||||||
|
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
|
||||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
|
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
|
||||||
|
1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
|
||||||
|
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
|
||||||
|
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调。
|
||||||
|
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -549,17 +729,19 @@ run_name: test_run # 可选
|
|||||||
|
|
||||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||||
|
|
||||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||||
|
|
||||||
## 引用
|
## 引用
|
||||||
|
|
||||||
如果您觉得此项目有帮助,请考虑以下列格式引用
|
如果您觉得此项目有帮助,请考虑以下列格式引用
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@article{zheng2024llamafactory,
|
@inproceedings{zheng2024llamafactory,
|
||||||
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
||||||
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
|
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
|
||||||
journal={arXiv preprint arXiv:2403.13372},
|
booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
|
||||||
|
address={Bangkok, Thailand},
|
||||||
|
publisher={Association for Computational Linguistics},
|
||||||
year={2024},
|
year={2024},
|
||||||
url={http://arxiv.org/abs/2403.13372}
|
url={http://arxiv.org/abs/2403.13372}
|
||||||
}
|
}
|
||||||
|
|||||||
1630
assets/benchmark.svg
1630
assets/benchmark.svg
File diff suppressed because it is too large
Load Diff
|
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 28 KiB |
195
data/README.md
195
data/README.md
@@ -11,8 +11,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
|||||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||||
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
||||||
"subset": "the name of the subset. (optional, default: None)",
|
"subset": "the name of the subset. (optional, default: None)",
|
||||||
|
"split": "the name of dataset split to be used. (optional, default: train)",
|
||||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||||
"num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
|
"num_samples": "the number of samples in the dataset to be used. (optional, default: None)",
|
||||||
"columns (optional)": {
|
"columns (optional)": {
|
||||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||||
"query": "the column name in the dataset containing the queries. (default: input)",
|
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||||
@@ -22,6 +23,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
|||||||
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
||||||
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
||||||
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
||||||
|
"videos": "the column name in the dataset containing the video inputs. (default: None)",
|
||||||
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
||||||
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
||||||
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
||||||
@@ -106,7 +108,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
|
|
||||||
### Preference Dataset
|
### Preference Dataset
|
||||||
|
|
||||||
Preference datasets are used for reward modeling, DPO training and ORPO training.
|
Preference datasets are used for reward modeling, DPO training, ORPO and SimPO training.
|
||||||
|
|
||||||
It requires a better response in `chosen` column and a worse response in `rejected` column.
|
It requires a better response in `chosen` column and a worse response in `rejected` column.
|
||||||
|
|
||||||
@@ -138,67 +140,15 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
|
|
||||||
### KTO Dataset
|
### KTO Dataset
|
||||||
|
|
||||||
- [Example dataset](kto_en_demo.json)
|
An additional column `kto_tag` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
|
||||||
|
|
||||||
KTO datasets require an extra `kto_tag` column containing the boolean human feedback.
|
### Multimodal Image Dataset
|
||||||
|
|
||||||
```json
|
An additional column `images` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
|
||||||
[
|
|
||||||
{
|
|
||||||
"instruction": "human instruction (required)",
|
|
||||||
"input": "human input (optional)",
|
|
||||||
"output": "model response (required)",
|
|
||||||
"kto_tag": "human feedback [true/false] (required)"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
### Multimodal Video Dataset
|
||||||
|
|
||||||
```json
|
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
|
||||||
"dataset_name": {
|
|
||||||
"file_name": "data.json",
|
|
||||||
"columns": {
|
|
||||||
"prompt": "instruction",
|
|
||||||
"query": "input",
|
|
||||||
"response": "output",
|
|
||||||
"kto_tag": "kto_tag"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multimodal Dataset
|
|
||||||
|
|
||||||
- [Example dataset](mllm_demo.json)
|
|
||||||
|
|
||||||
Multimodal datasets require an `images` column containing the paths to the input images. Currently we only support one image.
|
|
||||||
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"instruction": "human instruction (required)",
|
|
||||||
"input": "human input (optional)",
|
|
||||||
"output": "model response (required)",
|
|
||||||
"images": [
|
|
||||||
"image path (required)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
|
||||||
|
|
||||||
```json
|
|
||||||
"dataset_name": {
|
|
||||||
"file_name": "data.json",
|
|
||||||
"columns": {
|
|
||||||
"prompt": "instruction",
|
|
||||||
"query": "input",
|
|
||||||
"response": "output",
|
|
||||||
"images": "images"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Sharegpt Format
|
## Sharegpt Format
|
||||||
|
|
||||||
@@ -251,6 +201,10 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Pre-training Dataset
|
||||||
|
|
||||||
|
Not yet supported, please use the [alpaca](#alpaca-format) format.
|
||||||
|
|
||||||
### Preference Dataset
|
### Preference Dataset
|
||||||
|
|
||||||
- [Example dataset](dpo_en_demo.json)
|
- [Example dataset](dpo_en_demo.json)
|
||||||
@@ -301,6 +255,125 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### KTO Dataset
|
||||||
|
|
||||||
|
- [Example dataset](kto_en_demo.json)
|
||||||
|
|
||||||
|
KTO datasets require a extra `kto_tag` column containing the boolean human feedback.
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "human instruction"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "model response"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kto_tag": "human feedback [true/false] (required)"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"dataset_name": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"kto_tag": "kto_tag"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multimodal Image Dataset
|
||||||
|
|
||||||
|
- [Example dataset](mllm_demo.json)
|
||||||
|
|
||||||
|
Multimodal image datasets require a `images` column containing the paths to the input images.
|
||||||
|
|
||||||
|
The number of images should be identical to the `<image>` tokens in the conversations.
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "<image>human instruction"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "model response"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"images": [
|
||||||
|
"image path (required)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"dataset_name": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"images": "images"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multimodal Video Dataset
|
||||||
|
|
||||||
|
- [Example dataset](mllm_video_demo.json)
|
||||||
|
|
||||||
|
Multimodal video datasets require a `videos` column containing the paths to the input videos.
|
||||||
|
|
||||||
|
The number of videos should be identical to the `<video>` tokens in the conversations.
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "<video>human instruction"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "model response"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"videos": [
|
||||||
|
"video path (required)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"dataset_name": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"videos": "videos"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### OpenAI Format
|
### OpenAI Format
|
||||||
|
|
||||||
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
|
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
|
||||||
@@ -344,7 +417,3 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
The KTO datasets and multimodal datasets in sharegpt format are similar to the alpaca format.
|
|
||||||
|
|
||||||
Pre-training datasets are **incompatible** with the sharegpt format.
|
|
||||||
|
|||||||
@@ -11,8 +11,9 @@
|
|||||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||||
"subset": "数据集子集的名称(可选,默认:None)",
|
"subset": "数据集子集的名称(可选,默认:None)",
|
||||||
|
"split": "所使用的数据集切分(可选,默认:train)",
|
||||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||||
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
"num_samples": "该数据集所使用的样本数量。(可选,默认:None)",
|
||||||
"columns(可选)": {
|
"columns(可选)": {
|
||||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||||
"query": "数据集代表请求的表头名称(默认:input)",
|
"query": "数据集代表请求的表头名称(默认:input)",
|
||||||
@@ -22,6 +23,7 @@
|
|||||||
"system": "数据集代表系统提示的表头名称(默认:None)",
|
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||||
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
||||||
"images": "数据集代表图像输入的表头名称(默认:None)",
|
"images": "数据集代表图像输入的表头名称(默认:None)",
|
||||||
|
"videos": "数据集代表视频输入的表头名称(默认:None)",
|
||||||
"chosen": "数据集代表更优回答的表头名称(默认:None)",
|
"chosen": "数据集代表更优回答的表头名称(默认:None)",
|
||||||
"rejected": "数据集代表更差回答的表头名称(默认:None)",
|
"rejected": "数据集代表更差回答的表头名称(默认:None)",
|
||||||
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
||||||
@@ -106,7 +108,7 @@
|
|||||||
|
|
||||||
### 偏好数据集
|
### 偏好数据集
|
||||||
|
|
||||||
偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。
|
偏好数据集用于奖励模型训练、DPO 训练、ORPO 训练和 SimPO 训练。
|
||||||
|
|
||||||
它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
|
它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
|
||||||
|
|
||||||
@@ -138,67 +140,15 @@
|
|||||||
|
|
||||||
### KTO 数据集
|
### KTO 数据集
|
||||||
|
|
||||||
- [样例数据集](kto_en_demo.json)
|
KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
|
||||||
|
|
||||||
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
|
### 多模态图像数据集
|
||||||
|
|
||||||
```json
|
多模态图像数据集需要提供额外的 `images` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
|
||||||
[
|
|
||||||
{
|
|
||||||
"instruction": "人类指令(必填)",
|
|
||||||
"input": "人类输入(选填)",
|
|
||||||
"output": "模型回答(必填)",
|
|
||||||
"kto_tag": "人类反馈 [true/false](必填)"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
### 多模态视频数据集
|
||||||
|
|
||||||
```json
|
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
|
||||||
"数据集名称": {
|
|
||||||
"file_name": "data.json",
|
|
||||||
"columns": {
|
|
||||||
"prompt": "instruction",
|
|
||||||
"query": "input",
|
|
||||||
"response": "output",
|
|
||||||
"kto_tag": "kto_tag"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 多模态数据集
|
|
||||||
|
|
||||||
- [样例数据集](mllm_demo.json)
|
|
||||||
|
|
||||||
多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。目前我们仅支持单张图像输入。
|
|
||||||
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"instruction": "人类指令(必填)",
|
|
||||||
"input": "人类输入(选填)",
|
|
||||||
"output": "模型回答(必填)",
|
|
||||||
"images": [
|
|
||||||
"图像路径(必填)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
|
||||||
|
|
||||||
```json
|
|
||||||
"数据集名称": {
|
|
||||||
"file_name": "data.json",
|
|
||||||
"columns": {
|
|
||||||
"prompt": "instruction",
|
|
||||||
"query": "input",
|
|
||||||
"response": "output",
|
|
||||||
"images": "images"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Sharegpt 格式
|
## Sharegpt 格式
|
||||||
|
|
||||||
@@ -251,6 +201,10 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 预训练数据集
|
||||||
|
|
||||||
|
尚不支持,请使用 [alpaca](#alpaca-格式) 格式。
|
||||||
|
|
||||||
### 偏好数据集
|
### 偏好数据集
|
||||||
|
|
||||||
- [样例数据集](dpo_zh_demo.json)
|
- [样例数据集](dpo_zh_demo.json)
|
||||||
@@ -301,6 +255,125 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### KTO 数据集
|
||||||
|
|
||||||
|
- [样例数据集](kto_en_demo.json)
|
||||||
|
|
||||||
|
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "人类指令"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "模型回答"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kto_tag": "人类反馈 [true/false](必填)"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"数据集名称": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"kto_tag": "kto_tag"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 多模态图像数据集
|
||||||
|
|
||||||
|
- [样例数据集](mllm_demo.json)
|
||||||
|
|
||||||
|
多模态图像数据集需要额外添加一个 `images` 列,包含输入图像的路径。
|
||||||
|
|
||||||
|
注意图片的数量必须与文本中所有 `<image>` 标记的数量严格一致。
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "<image>人类指令"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "模型回答"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"images": [
|
||||||
|
"图像路径(必填)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"数据集名称": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"images": "images"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 多模态视频数据集
|
||||||
|
|
||||||
|
- [样例数据集](mllm_video_demo.json)
|
||||||
|
|
||||||
|
多模态视频数据集需要额外添加一个 `videos` 列,包含输入视频的路径。
|
||||||
|
|
||||||
|
注意视频的数量必须与文本中所有 `<video>` 标记的数量严格一致。
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "<video>人类指令"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "模型回答"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"videos": [
|
||||||
|
"视频路径(必填)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"数据集名称": {
|
||||||
|
"file_name": "data.json",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "conversations",
|
||||||
|
"videos": "videos"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### OpenAI 格式
|
### OpenAI 格式
|
||||||
|
|
||||||
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
|
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
|
||||||
@@ -344,7 +417,3 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Sharegpt 格式中的 KTO 数据集和多模态数据集与 alpaca 格式的类似。
|
|
||||||
|
|
||||||
预训练数据集**不支持** sharegpt 格式。
|
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ _CITATION = """\
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
|
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
|
||||||
_LICENSE = "gpl-3.0"
|
_LICENSE = "gpl-3.0"
|
||||||
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
|
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
|
||||||
|
|
||||||
|
|
||||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||||
@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|||||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||||
|
|
||||||
def _generate_examples(self, filepath: str):
|
def _generate_examples(self, filepath: str):
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, encoding="utf-8") as f:
|
||||||
for key, row in enumerate(f):
|
for key, row in enumerate(f):
|
||||||
data = json.loads(row)
|
data = json.loads(row)
|
||||||
conversations = []
|
conversations = []
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ import datasets
|
|||||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||||
_CITATION = ""
|
_CITATION = ""
|
||||||
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
|
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
|
||||||
_LICENSE = "mit"
|
_LICENSE = "mit"
|
||||||
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
|
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
|
||||||
_URLS = {
|
_URLS = {
|
||||||
"train": [
|
"train": [
|
||||||
_URL + "harmless-base/train.jsonl.gz",
|
_URL + "harmless-base/train.jsonl.gz",
|
||||||
@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
|||||||
def _generate_examples(self, filepaths: List[str]):
|
def _generate_examples(self, filepaths: List[str]):
|
||||||
key = 0
|
key = 0
|
||||||
for filepath in filepaths:
|
for filepath in filepaths:
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, encoding="utf-8") as f:
|
||||||
for row in f:
|
for row in f:
|
||||||
data = json.loads(row)
|
data = json.loads(row)
|
||||||
chosen = data["chosen"]
|
chosen = data["chosen"]
|
||||||
|
|||||||
BIN
data/mllm_demo_data/1.mp4
Normal file
BIN
data/mllm_demo_data/1.mp4
Normal file
Binary file not shown.
BIN
data/mllm_demo_data/2.avi
Normal file
BIN
data/mllm_demo_data/2.avi
Normal file
Binary file not shown.
BIN
data/mllm_demo_data/3.mp4
Normal file
BIN
data/mllm_demo_data/3.mp4
Normal file
Binary file not shown.
@@ -20,9 +20,9 @@ _CITATION = """\
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
|
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
|
||||||
_LICENSE = "cc-by-nc-4.0"
|
_LICENSE = "cc-by-nc-4.0"
|
||||||
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
|
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
|
||||||
|
|
||||||
|
|
||||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||||
@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
|||||||
|
|
||||||
def _generate_examples(self, filepaths: List[str]):
|
def _generate_examples(self, filepaths: List[str]):
|
||||||
for filepath in filepaths:
|
for filepath in filepaths:
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, encoding="utf-8") as f:
|
||||||
for row in f:
|
for row in f:
|
||||||
try:
|
try:
|
||||||
data = json.loads(row)
|
data = json.loads(row)
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
72
docker/docker-cuda/Dockerfile
Normal file
72
docker/docker-cuda/Dockerfile
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Default use the NVIDIA official image with PyTorch 2.3.0
|
||||||
|
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
|
||||||
|
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
|
||||||
|
FROM ${BASE_IMAGE}
|
||||||
|
|
||||||
|
# Define environments
|
||||||
|
ENV MAX_JOBS=4
|
||||||
|
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
|
||||||
|
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
|
||||||
|
# Define installation arguments
|
||||||
|
ARG INSTALL_BNB=false
|
||||||
|
ARG INSTALL_VLLM=false
|
||||||
|
ARG INSTALL_DEEPSPEED=false
|
||||||
|
ARG INSTALL_FLASHATTN=false
|
||||||
|
ARG INSTALL_LIGER_KERNEL=false
|
||||||
|
ARG INSTALL_HQQ=false
|
||||||
|
ARG INSTALL_EETQ=false
|
||||||
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
|
|
||||||
|
# Set the working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install the requirements
|
||||||
|
COPY requirements.txt /app
|
||||||
|
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||||
|
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||||
|
python -m pip install --upgrade pip && \
|
||||||
|
python -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Copy the rest of the application into the image
|
||||||
|
COPY . /app
|
||||||
|
|
||||||
|
# Install the LLaMA Factory
|
||||||
|
RUN EXTRA_PACKAGES="metrics"; \
|
||||||
|
if [ "$INSTALL_BNB" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_VLLM" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_EETQ" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
|
||||||
|
fi; \
|
||||||
|
pip install -e ".[$EXTRA_PACKAGES]"
|
||||||
|
|
||||||
|
# Rebuild flash attention
|
||||||
|
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||||
|
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||||
|
pip uninstall -y ninja && pip install ninja && \
|
||||||
|
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set up volumes
|
||||||
|
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||||
|
|
||||||
|
# Expose port 7860 for the LLaMA Board
|
||||||
|
ENV GRADIO_SERVER_PORT 7860
|
||||||
|
EXPOSE 7860
|
||||||
|
|
||||||
|
# Expose port 8000 for the API service
|
||||||
|
ENV API_PORT 8000
|
||||||
|
EXPOSE 8000
|
||||||
@@ -1,23 +1,30 @@
|
|||||||
services:
|
services:
|
||||||
llamafactory:
|
llamafactory:
|
||||||
build:
|
build:
|
||||||
dockerfile: Dockerfile
|
dockerfile: ./docker/docker-cuda/Dockerfile
|
||||||
context: .
|
context: ../..
|
||||||
args:
|
args:
|
||||||
INSTALL_BNB: false
|
INSTALL_BNB: false
|
||||||
INSTALL_VLLM: false
|
INSTALL_VLLM: false
|
||||||
INSTALL_DEEPSPEED: false
|
INSTALL_DEEPSPEED: false
|
||||||
|
INSTALL_FLASHATTN: false
|
||||||
|
INSTALL_LIGER_KERNEL: false
|
||||||
|
INSTALL_HQQ: false
|
||||||
|
INSTALL_EETQ: false
|
||||||
PIP_INDEX: https://pypi.org/simple
|
PIP_INDEX: https://pypi.org/simple
|
||||||
container_name: llamafactory
|
container_name: llamafactory
|
||||||
volumes:
|
volumes:
|
||||||
- ./hf_cache:/root/.cache/huggingface/
|
- ../../hf_cache:/root/.cache/huggingface
|
||||||
- ./data:/app/data
|
- ../../ms_cache:/root/.cache/modelscope
|
||||||
- ./output:/app/output
|
- ../../om_cache:/root/.cache/openmind
|
||||||
|
- ../../data:/app/data
|
||||||
|
- ../../output:/app/output
|
||||||
ports:
|
ports:
|
||||||
- "7860:7860"
|
- "7860:7860"
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
ipc: host
|
ipc: host
|
||||||
tty: true
|
tty: true
|
||||||
|
shm_size: '16gb'
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
command: bash
|
command: bash
|
||||||
deploy:
|
deploy:
|
||||||
45
docker/docker-npu/Dockerfile
Normal file
45
docker/docker-npu/Dockerfile
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
|
||||||
|
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
|
||||||
|
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
|
||||||
|
FROM ascendai/cann:8.0.rc1-910b-ubuntu22.04-py3.8
|
||||||
|
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
|
||||||
|
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8
|
||||||
|
|
||||||
|
# Define environments
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Define installation arguments
|
||||||
|
ARG INSTALL_DEEPSPEED=false
|
||||||
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
|
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
|
# Set the working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install the requirements
|
||||||
|
COPY requirements.txt /app
|
||||||
|
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||||
|
pip config set global.extra-index-url "$TORCH_INDEX" && \
|
||||||
|
python -m pip install --upgrade pip && \
|
||||||
|
python -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Copy the rest of the application into the image
|
||||||
|
COPY . /app
|
||||||
|
|
||||||
|
# Install the LLaMA Factory
|
||||||
|
RUN EXTRA_PACKAGES="torch-npu,metrics"; \
|
||||||
|
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
|
fi; \
|
||||||
|
pip install -e ".[$EXTRA_PACKAGES]"
|
||||||
|
|
||||||
|
# Set up volumes
|
||||||
|
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||||
|
|
||||||
|
# Expose port 7860 for the LLaMA Board
|
||||||
|
ENV GRADIO_SERVER_PORT 7860
|
||||||
|
EXPOSE 7860
|
||||||
|
|
||||||
|
# Expose port 8000 for the API service
|
||||||
|
ENV API_PORT 8000
|
||||||
|
EXPOSE 8000
|
||||||
33
docker/docker-npu/docker-compose.yml
Normal file
33
docker/docker-npu/docker-compose.yml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
services:
|
||||||
|
llamafactory:
|
||||||
|
build:
|
||||||
|
dockerfile: ./docker/docker-npu/Dockerfile
|
||||||
|
context: ../..
|
||||||
|
args:
|
||||||
|
INSTALL_DEEPSPEED: false
|
||||||
|
PIP_INDEX: https://pypi.org/simple
|
||||||
|
container_name: llamafactory
|
||||||
|
volumes:
|
||||||
|
- ../../hf_cache:/root/.cache/huggingface
|
||||||
|
- ../../ms_cache:/root/.cache/modelscope
|
||||||
|
- ../../om_cache:/root/.cache/openmind
|
||||||
|
- ../../data:/app/data
|
||||||
|
- ../../output:/app/output
|
||||||
|
- /usr/local/dcmi:/usr/local/dcmi
|
||||||
|
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
|
||||||
|
- /usr/local/Ascend/driver:/usr/local/Ascend/driver
|
||||||
|
- /etc/ascend_install.info:/etc/ascend_install.info
|
||||||
|
ports:
|
||||||
|
- "7860:7860"
|
||||||
|
- "8000:8000"
|
||||||
|
ipc: host
|
||||||
|
tty: true
|
||||||
|
shm_size: '16gb'
|
||||||
|
stdin_open: true
|
||||||
|
command: bash
|
||||||
|
devices:
|
||||||
|
- /dev/davinci0
|
||||||
|
- /dev/davinci_manager
|
||||||
|
- /dev/devmm_svm
|
||||||
|
- /dev/hisi_hdc
|
||||||
|
restart: unless-stopped
|
||||||
65
docker/docker-rocm/Dockerfile
Normal file
65
docker/docker-rocm/Dockerfile
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
FROM hardandheavy/transformers-rocm:2.2.0
|
||||||
|
|
||||||
|
# Define environments
|
||||||
|
ENV MAX_JOBS=4
|
||||||
|
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
|
||||||
|
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
|
||||||
|
# Define installation arguments
|
||||||
|
ARG INSTALL_BNB=false
|
||||||
|
ARG INSTALL_VLLM=false
|
||||||
|
ARG INSTALL_DEEPSPEED=false
|
||||||
|
ARG INSTALL_FLASHATTN=false
|
||||||
|
ARG INSTALL_LIGER_KERNEL=false
|
||||||
|
ARG INSTALL_HQQ=false
|
||||||
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
|
|
||||||
|
# Set the working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install the requirements
|
||||||
|
COPY requirements.txt /app
|
||||||
|
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||||
|
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||||
|
python -m pip install --upgrade pip && \
|
||||||
|
python -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Copy the rest of the application into the image
|
||||||
|
COPY . /app
|
||||||
|
|
||||||
|
# Install the LLaMA Factory
|
||||||
|
RUN EXTRA_PACKAGES="metrics"; \
|
||||||
|
if [ "$INSTALL_BNB" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_VLLM" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
|
||||||
|
fi; \
|
||||||
|
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||||
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||||
|
fi; \
|
||||||
|
pip install -e ".[$EXTRA_PACKAGES]"
|
||||||
|
|
||||||
|
# Rebuild flash attention
|
||||||
|
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||||
|
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||||
|
pip uninstall -y ninja && pip install ninja && \
|
||||||
|
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set up volumes
|
||||||
|
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||||
|
|
||||||
|
# Expose port 7860 for the LLaMA Board
|
||||||
|
ENV GRADIO_SERVER_PORT 7860
|
||||||
|
EXPOSE 7860
|
||||||
|
|
||||||
|
# Expose port 8000 for the API service
|
||||||
|
ENV API_PORT 8000
|
||||||
|
EXPOSE 8000
|
||||||
33
docker/docker-rocm/docker-compose.yml
Normal file
33
docker/docker-rocm/docker-compose.yml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
services:
|
||||||
|
llamafactory:
|
||||||
|
build:
|
||||||
|
dockerfile: ./docker/docker-rocm/Dockerfile
|
||||||
|
context: ../..
|
||||||
|
args:
|
||||||
|
INSTALL_BNB: false
|
||||||
|
INSTALL_VLLM: false
|
||||||
|
INSTALL_DEEPSPEED: false
|
||||||
|
INSTALL_FLASHATTN: false
|
||||||
|
INSTALL_LIGER_KERNEL: false
|
||||||
|
INSTALL_HQQ: false
|
||||||
|
PIP_INDEX: https://pypi.org/simple
|
||||||
|
container_name: llamafactory
|
||||||
|
volumes:
|
||||||
|
- ../../hf_cache:/root/.cache/huggingface
|
||||||
|
- ../../ms_cache:/root/.cache/modelscope
|
||||||
|
- ../../om_cache:/root/.cache/openmind
|
||||||
|
- ../../data:/app/data
|
||||||
|
- ../../output:/app/output
|
||||||
|
- ../../saves:/app/saves
|
||||||
|
ports:
|
||||||
|
- "7860:7860"
|
||||||
|
- "8000:8000"
|
||||||
|
ipc: host
|
||||||
|
tty: true
|
||||||
|
shm_size: '16gb'
|
||||||
|
stdin_open: true
|
||||||
|
command: bash
|
||||||
|
devices:
|
||||||
|
- /dev/kfd:/dev/kfd
|
||||||
|
- /dev/dri:/dev/dri
|
||||||
|
restart: unless-stopped
|
||||||
@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
|||||||
df = pd.read_csv(filepath, header=None)
|
df = pd.read_csv(filepath, header=None)
|
||||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||||
|
|
||||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
yield from enumerate(df.to_dict(orient="records"))
|
||||||
yield i, instance
|
|
||||||
|
|||||||
@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||||
|
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### DPO/ORPO/SimPO Training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Multimodal DPO/ORPO/SimPO Training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Reward Modeling
|
#### Reward Modeling
|
||||||
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
|||||||
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### DPO/ORPO/SimPO Training
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### KTO Training
|
#### KTO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -82,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
|||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||||
@@ -94,10 +101,10 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
|
|||||||
|
|
||||||
### QLoRA Fine-Tuning
|
### QLoRA Fine-Tuning
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||||
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
|
|||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Multimodal Supervised Fine-Tuning
|
||||||
|
|
||||||
|
```bash
|
||||||
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
|||||||
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Full-Parameter Fine-Tuning using Adam-mini
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### LoRA+ Fine-Tuning
|
#### LoRA+ Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||||
|
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### DPO/ORPO/SimPO 训练
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 多模态 DPO/ORPO/SimPO 训练
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 奖励模型训练
|
#### 奖励模型训练
|
||||||
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
|||||||
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### DPO/ORPO/SimPO 训练
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### KTO 训练
|
#### KTO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -82,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
|||||||
#### 多机指令监督微调
|
#### 多机指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||||
@@ -94,10 +101,10 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
|
|||||||
|
|
||||||
### QLoRA 微调
|
### QLoRA 微调
|
||||||
|
|
||||||
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||||
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
|
|||||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 多模态指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
|||||||
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 使用 Adam-mini 进行全参数训练
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### LoRA+ 微调
|
#### LoRA+ 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
39
examples/extras/adam_mini/qwen2_full_sft.yaml
Normal file
39
examples/extras/adam_mini/qwen2_full_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
use_adam_mini: true
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: identity,alpaca_en_demo
|
||||||
|
template: qwen
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/qwen2-1_5b/full/sft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
val_size: 0.1
|
||||||
|
per_device_eval_batch_size: 1
|
||||||
|
eval_strategy: steps
|
||||||
|
eval_steps: 500
|
||||||
@@ -6,14 +6,16 @@ stage: sft
|
|||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
use_badam: true
|
use_badam: true
|
||||||
|
badam_mode: layer
|
||||||
badam_switch_mode: ascending
|
badam_switch_mode: ascending
|
||||||
badam_switch_interval: 50
|
badam_switch_interval: 50
|
||||||
badam_verbose: 2
|
badam_verbose: 2
|
||||||
|
# deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -28,11 +30,10 @@ overwrite_output_dir: true
|
|||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-5
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
pure_bf16: true
|
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -11,7 +11,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ galore_scale: 2.0
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -29,11 +29,12 @@ overwrite_output_dir: true
|
|||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-5
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
pure_bf16: true
|
pure_bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
python scripts/llama_pro.py \
|
python scripts/llama_pro.py \
|
||||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||||
--output_dir models/llama3-8b-instruct-pro \
|
--output_dir models/llama3-8b-pro \
|
||||||
--num_expand 8
|
--num_expand 8
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: models/llama3-8b-instruct-pro
|
model_name_or_path: models/llama3-8b-pro
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -12,13 +12,13 @@ use_llama_pro: true
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b-instruct-pro/freeze/sft
|
output_dir: saves/llama3-8b-pro/freeze/sft
|
||||||
logging_steps: 10
|
logging_steps: 10
|
||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
@@ -31,7 +31,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ loraplus_lr_ratio: 16.0
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ mixture_of_depths: convert
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -26,7 +26,7 @@ overwrite_output_dir: true
|
|||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
optim: paged_adamw_8bit
|
optim: paged_adamw_8bit
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-5
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
5
examples/extras/pissa/init.sh
Normal file
5
examples/extras/pissa/init.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python scripts/pissa_init.py \
|
||||||
|
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||||
|
--output_dir models/llama3-8b-pissa
|
||||||
@@ -7,13 +7,13 @@ do_train: true
|
|||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pissa_init: true
|
pissa_init: true
|
||||||
pissa_iter: 4
|
pissa_iter: 16
|
||||||
pissa_convert: true
|
pissa_convert: true
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -32,7 +32,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
2
examples/inference/llava1_5.yaml
Normal file
2
examples/inference/llava1_5.yaml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
||||||
|
template: llava
|
||||||
2
examples/inference/qwen2_vl.yaml
Normal file
2
examples/inference/qwen2_vl.yaml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
template: qwen2_vl
|
||||||
13
examples/merge_lora/qwen2vl_lora_sft.yaml
Normal file
13
examples/merge_lora/qwen2vl_lora_sft.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
|
||||||
|
|
||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
|
||||||
|
template: qwen2_vl
|
||||||
|
finetuning_type: lora
|
||||||
|
|
||||||
|
### export
|
||||||
|
export_dir: models/qwen2_vl_lora_sft
|
||||||
|
export_size: 2
|
||||||
|
export_device: cpu
|
||||||
|
export_legacy_format: false
|
||||||
@@ -7,9 +7,9 @@ do_predict: true
|
|||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
eval_dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 50
|
max_samples: 50
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -25,11 +25,11 @@ overwrite_output_dir: true
|
|||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-5
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
39
examples/train_full/qwen2vl_full_sft.yaml
Normal file
39
examples/train_full/qwen2vl_full_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: mllm_demo,identity
|
||||||
|
template: qwen2_vl
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/qwen2_vl-7b/full/sft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
val_size: 0.1
|
||||||
|
per_device_eval_batch_size: 1
|
||||||
|
eval_strategy: steps
|
||||||
|
eval_steps: 500
|
||||||
@@ -7,12 +7,12 @@ do_train: true
|
|||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
pref_beta: 0.1
|
pref_beta: 0.1
|
||||||
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
|
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: dpo_en_demo
|
dataset: dpo_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -31,7 +31,7 @@ learning_rate: 5.0e-6
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ adapter_name_or_path: saves/llama3-8b/lora/sft
|
|||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
task: mmlu
|
task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test]
|
||||||
split: test
|
|
||||||
template: fewshot
|
template: fewshot
|
||||||
lang: en
|
lang: en
|
||||||
n_shot: 5
|
n_shot: 5
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ pref_beta: 0.1
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: kto_en_demo
|
dataset: kto_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +30,7 @@ learning_rate: 5.0e-6
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-5
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### generate
|
### generate
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ do_predict: true
|
|||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
eval_dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 50
|
max_samples: 50
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ lora_target: all
|
|||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: c4_demo
|
dataset: c4_demo
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b/lora/sft
|
output_dir: saves/llama3-8b/lora/pretrain
|
||||||
logging_steps: 10
|
logging_steps: 10
|
||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
@@ -28,7 +28,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: dpo_en_demo
|
dataset: dpo_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -25,11 +25,11 @@ overwrite_output_dir: true
|
|||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
learning_rate: 1.0e-5
|
learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ deepspeed: examples/deepspeed/ds_z0_config.json
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
||||||
visual_inputs: true
|
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -10,8 +9,8 @@ lora_target: all
|
|||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: mllm_demo
|
dataset: mllm_demo
|
||||||
template: vicuna
|
template: llava
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +29,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
41
examples/train_lora/qwen2vl_lora_dpo.yaml
Normal file
41
examples/train_lora/qwen2vl_lora_dpo.yaml
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: dpo
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: lora
|
||||||
|
lora_target: all
|
||||||
|
pref_beta: 0.1
|
||||||
|
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: rlhf_v
|
||||||
|
template: qwen2_vl
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/qwen2_vl-7b/lora/dpo
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 5.0e-6
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
val_size: 0.1
|
||||||
|
per_device_eval_batch_size: 1
|
||||||
|
eval_strategy: steps
|
||||||
|
eval_steps: 500
|
||||||
39
examples/train_lora/qwen2vl_lora_sft.yaml
Normal file
39
examples/train_lora/qwen2vl_lora_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: lora
|
||||||
|
lora_target: all
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: mllm_demo,identity # video: mllm_video_demo
|
||||||
|
template: qwen2_vl
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/qwen2_vl-7b/lora/sft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
val_size: 0.1
|
||||||
|
per_device_eval_batch_size: 1
|
||||||
|
eval_strategy: steps
|
||||||
|
eval_steps: 500
|
||||||
@@ -10,7 +10,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
quantization_bit: 4
|
quantization_bit: 4
|
||||||
|
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
@@ -11,7 +12,7 @@ lora_target: all
|
|||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
cutoff_len: 1024
|
cutoff_len: 2048
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
overwrite_cache: true
|
overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
@@ -30,7 +31,7 @@ learning_rate: 1.0e-4
|
|||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
bf16: true
|
||||||
ddp_timeout: 180000000
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
transformers>=4.41.2
|
transformers>=4.41.2,<=4.46.1
|
||||||
datasets>=2.16.0
|
datasets>=2.16.0,<=3.1.0
|
||||||
accelerate>=0.30.1
|
accelerate>=0.34.0,<=1.0.1
|
||||||
peft>=0.11.1
|
peft>=0.11.1,<=0.12.0
|
||||||
trl>=0.8.6
|
trl>=0.8.6,<=0.9.6
|
||||||
gradio>=4.0.0
|
gradio>=4.0.0,<5.0.0
|
||||||
pandas>=2.0.0
|
pandas>=2.0.0
|
||||||
scipy
|
scipy
|
||||||
einops
|
einops
|
||||||
@@ -18,3 +18,6 @@ matplotlib>=3.7.0
|
|||||||
fire
|
fire
|
||||||
packaging
|
packaging
|
||||||
pyyaml
|
pyyaml
|
||||||
|
numpy<2.0.0
|
||||||
|
av
|
||||||
|
tyro<0.9.0
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by the Microsoft's DeepSpeed library.
|
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||||
@@ -27,7 +26,7 @@ from llamafactory.chat import ChatModel
|
|||||||
def calculate_flops(
|
def calculate_flops(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
seq_length: int = 256,
|
seq_length: int = 512,
|
||||||
flash_attn: str = "auto",
|
flash_attn: str = "auto",
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -36,9 +35,11 @@ def calculate_flops(
|
|||||||
"""
|
"""
|
||||||
with get_accelerator().device(0):
|
with get_accelerator().device(0):
|
||||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
|
||||||
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
||||||
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
|
flops, macs, params = get_model_profile(
|
||||||
|
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
|
||||||
|
)
|
||||||
print("FLOPs:", flops)
|
print("FLOPs:", flops)
|
||||||
print("MACs:", macs)
|
print("MACs:", macs)
|
||||||
print("Params:", params)
|
print("Params:", params)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 imoneoi and the LlamaFactory team.
|
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by the imoneoi's OpenChat library.
|
# This code is inspired by the imoneoi's OpenChat library.
|
||||||
@@ -25,7 +24,7 @@ from torch.utils.data import DataLoader
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||||
|
|
||||||
from llamafactory.data import get_dataset
|
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_train_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
@@ -39,15 +38,17 @@ def calculate_lr(
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||||
stage: Literal["pt", "sft"] = "sft",
|
stage: Literal["pt", "sft"] = "sft",
|
||||||
dataset: str = "alpaca_en",
|
dataset: str = "alpaca_en_demo",
|
||||||
dataset_dir: str = "data",
|
dataset_dir: str = "data",
|
||||||
template: str = "default",
|
template: str = "default",
|
||||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||||
is_mistral: bool = False, # mistral model uses a smaller learning rate,
|
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
|
||||||
|
packing: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||||
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
Usage:
|
||||||
|
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
|
||||||
"""
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
@@ -57,19 +58,22 @@ def calculate_lr(
|
|||||||
dataset_dir=dataset_dir,
|
dataset_dir=dataset_dir,
|
||||||
template=template,
|
template=template,
|
||||||
cutoff_len=cutoff_len,
|
cutoff_len=cutoff_len,
|
||||||
|
packing=packing,
|
||||||
output_dir="dummy_dir",
|
output_dir="dummy_dir",
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
|
do_train=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
elif stage == "sft":
|
elif stage == "sft":
|
||||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||||
|
|
||||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||||
valid_tokens, total_tokens = 0, 0
|
valid_tokens, total_tokens = 0, 0
|
||||||
@@ -81,7 +85,7 @@ def calculate_lr(
|
|||||||
valid_ratio = valid_tokens / total_tokens
|
valid_ratio = valid_tokens / total_tokens
|
||||||
batch_valid_len = batch_max_len * valid_ratio
|
batch_valid_len = batch_max_len * valid_ratio
|
||||||
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
|
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
|
||||||
lr = lr / 6.0 if is_mistral else lr
|
lr = lr / 6.0 if is_mistral_or_gemma else lr
|
||||||
print(
|
print(
|
||||||
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
||||||
lr, valid_ratio * 100, batch_valid_len
|
lr, valid_ratio * 100, batch_valid_len
|
||||||
|
|||||||
163
scripts/cal_mfu.py
Normal file
163
scripts/cal_mfu.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from llamafactory.train.tuner import run_exp
|
||||||
|
|
||||||
|
|
||||||
|
BASE = 2 # gemm (add + mul)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_model_flops(
|
||||||
|
model_name_or_path: str,
|
||||||
|
total_batch_size: int,
|
||||||
|
seq_length: int,
|
||||||
|
include_backward: bool = True,
|
||||||
|
include_recompute: bool = False,
|
||||||
|
include_flashattn: bool = False,
|
||||||
|
) -> int:
|
||||||
|
r"""
|
||||||
|
Calculates the FLOPs of model per forward/backward pass.
|
||||||
|
"""
|
||||||
|
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
|
hidden_size = getattr(config, "hidden_size", None)
|
||||||
|
vocab_size = getattr(config, "vocab_size", None)
|
||||||
|
intermediate_size = getattr(config, "intermediate_size", None)
|
||||||
|
num_attention_heads = getattr(config, "num_attention_heads", None)
|
||||||
|
num_key_value_heads = getattr(config, "num_key_value_heads", None)
|
||||||
|
num_hidden_layers = getattr(config, "num_hidden_layers", None)
|
||||||
|
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
|
||||||
|
|
||||||
|
# mlp module
|
||||||
|
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
|
||||||
|
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
|
||||||
|
|
||||||
|
# attn projector module
|
||||||
|
q_flops_per_token = BASE * hidden_size * hidden_size
|
||||||
|
o_flops_per_token = BASE * hidden_size * hidden_size
|
||||||
|
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
|
||||||
|
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
|
||||||
|
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
|
||||||
|
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
|
||||||
|
|
||||||
|
# attn sdpa module
|
||||||
|
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
|
||||||
|
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
|
||||||
|
|
||||||
|
# embedding module
|
||||||
|
embedding_flops_per_token = hidden_size * vocab_size
|
||||||
|
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
|
||||||
|
if tie_word_embeddings is False:
|
||||||
|
embedding_flops *= 2
|
||||||
|
|
||||||
|
non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
|
||||||
|
non_embedding_coeff, embedding_coeff = 1, 1
|
||||||
|
if include_backward:
|
||||||
|
non_embedding_coeff += 2
|
||||||
|
embedding_coeff += 2
|
||||||
|
|
||||||
|
if include_recompute:
|
||||||
|
non_embedding_coeff += 1
|
||||||
|
|
||||||
|
total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
|
||||||
|
|
||||||
|
if include_flashattn:
|
||||||
|
total_flops += sdpa_flops
|
||||||
|
|
||||||
|
return total_flops
|
||||||
|
|
||||||
|
|
||||||
|
def compute_device_flops(world_size: int) -> float:
|
||||||
|
r"""
|
||||||
|
Calculates the FLOPs of the device capability per second.
|
||||||
|
"""
|
||||||
|
device_name = torch.cuda.get_device_name()
|
||||||
|
if "H100" in device_name or "H800" in device_name:
|
||||||
|
return 989 * 1e12 * world_size
|
||||||
|
elif "A100" in device_name or "A800" in device_name:
|
||||||
|
return 312 * 1e12 * world_size
|
||||||
|
elif "V100" in device_name:
|
||||||
|
return 125 * 1e12 * world_size
|
||||||
|
elif "4090" in device_name:
|
||||||
|
return 98 * 1e12 * world_size
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Device not supported: {device_name}.")
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_mfu(
|
||||||
|
model_name_or_path: str,
|
||||||
|
batch_size: int = 1,
|
||||||
|
seq_length: int = 1024,
|
||||||
|
num_steps: int = 100,
|
||||||
|
finetuning_type: str = "lora",
|
||||||
|
flash_attn: str = "auto",
|
||||||
|
deepspeed_stage: int = 0,
|
||||||
|
disable_gc: bool = False,
|
||||||
|
liger_kernel: bool = False,
|
||||||
|
unsloth_gc: bool = False,
|
||||||
|
) -> float:
|
||||||
|
r"""
|
||||||
|
Calculates MFU for given model and hyper-params.
|
||||||
|
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
|
||||||
|
"""
|
||||||
|
args = {
|
||||||
|
"model_name_or_path": model_name_or_path,
|
||||||
|
"flash_attn": flash_attn,
|
||||||
|
"disable_gradient_checkpointing": disable_gc,
|
||||||
|
"enable_liger_kernel": liger_kernel,
|
||||||
|
"use_unsloth_gc": unsloth_gc,
|
||||||
|
"stage": "pt",
|
||||||
|
"do_train": True,
|
||||||
|
"finetuning_type": finetuning_type,
|
||||||
|
"dataset": "c4_demo",
|
||||||
|
"cutoff_len": seq_length,
|
||||||
|
"output_dir": os.path.join("saves", "test_mfu"),
|
||||||
|
"logging_strategy": "no",
|
||||||
|
"save_strategy": "no",
|
||||||
|
"save_only_model": True,
|
||||||
|
"overwrite_output_dir": True,
|
||||||
|
"per_device_train_batch_size": batch_size,
|
||||||
|
"max_steps": num_steps,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
if deepspeed_stage in [2, 3]:
|
||||||
|
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
|
||||||
|
|
||||||
|
run_exp(args)
|
||||||
|
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
|
||||||
|
result = json.load(f)
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
else:
|
||||||
|
world_size = 1
|
||||||
|
|
||||||
|
total_batch_size = batch_size * world_size
|
||||||
|
mfu_value = (
|
||||||
|
result["train_steps_per_second"]
|
||||||
|
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||||
|
/ compute_device_flops(world_size)
|
||||||
|
)
|
||||||
|
print(f"MFU: {mfu_value * 100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(calculate_mfu)
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -23,7 +22,7 @@ from torch.utils.data import DataLoader
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||||
|
|
||||||
from llamafactory.data import get_dataset
|
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_train_args
|
||||||
from llamafactory.model import load_model, load_tokenizer
|
from llamafactory.model import load_model, load_tokenizer
|
||||||
@@ -55,12 +54,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
return super().__call__(chosen_features)
|
return super().__call__(chosen_features)
|
||||||
|
|
||||||
|
|
||||||
def cal_ppl(
|
def calculate_ppl(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
save_name: str,
|
save_name: str,
|
||||||
batch_size: int = 4,
|
batch_size: int = 4,
|
||||||
stage: Literal["pt", "sft", "rm"] = "sft",
|
stage: Literal["pt", "sft", "rm"] = "sft",
|
||||||
dataset: str = "alpaca_en",
|
dataset: str = "alpaca_en_demo",
|
||||||
dataset_dir: str = "data",
|
dataset_dir: str = "data",
|
||||||
template: str = "default",
|
template: str = "default",
|
||||||
cutoff_len: int = 1024,
|
cutoff_len: int = 1024,
|
||||||
@@ -69,7 +68,7 @@ def cal_ppl(
|
|||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculates the ppl on the dataset of the pre-trained models.
|
Calculates the ppl on the dataset of the pre-trained models.
|
||||||
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
|
||||||
"""
|
"""
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
@@ -83,11 +82,13 @@ def cal_ppl(
|
|||||||
train_on_prompt=train_on_prompt,
|
train_on_prompt=train_on_prompt,
|
||||||
output_dir="dummy_dir",
|
output_dir="dummy_dir",
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
|
do_train=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
@@ -98,7 +99,7 @@ def cal_ppl(
|
|||||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||||
|
|
||||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
@@ -123,9 +124,9 @@ def cal_ppl(
|
|||||||
with open(save_name, "w", encoding="utf-8") as f:
|
with open(save_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(perplexities, f, indent=2)
|
json.dump(perplexities, f, indent=2)
|
||||||
|
|
||||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
|
||||||
print("Perplexities have been saved at {}.".format(save_name))
|
print(f"Perplexities have been saved at {save_name}.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(cal_ppl)
|
fire.Fire(calculate_ppl)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -18,21 +17,21 @@ from collections import defaultdict
|
|||||||
import fire
|
import fire
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from llamafactory.data import get_dataset
|
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_train_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
def length_cdf(
|
def length_cdf(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
dataset: str = "alpaca_en",
|
dataset: str = "alpaca_en_demo",
|
||||||
dataset_dir: str = "data",
|
dataset_dir: str = "data",
|
||||||
template: str = "default",
|
template: str = "default",
|
||||||
interval: int = 1000,
|
interval: int = 1000,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculates the distribution of the input lengths in the dataset.
|
Calculates the distribution of the input lengths in the dataset.
|
||||||
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
|
||||||
"""
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
@@ -44,10 +43,12 @@ def length_cdf(
|
|||||||
cutoff_len=1_000_000,
|
cutoff_len=1_000_000,
|
||||||
output_dir="dummy_dir",
|
output_dir="dummy_dir",
|
||||||
overwrite_cache=True,
|
overwrite_cache=True,
|
||||||
|
do_train=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||||
|
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
|
||||||
total_num = len(trainset)
|
total_num = len(trainset)
|
||||||
length_dict = defaultdict(int)
|
length_dict = defaultdict(int)
|
||||||
for sample in tqdm(trainset["input_ids"]):
|
for sample in tqdm(trainset["input_ids"]):
|
||||||
@@ -59,7 +60,7 @@ def length_cdf(
|
|||||||
for length, count in length_tuples:
|
for length, count in length_tuples:
|
||||||
count_accu += count
|
count_accu += count
|
||||||
prob_accu += count / total_num * 100
|
prob_accu += count / total_num * 100
|
||||||
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by the Tencent's LLaMA-Pro library.
|
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||||
@@ -19,7 +18,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
@@ -40,15 +39,15 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def change_name(name: str, old_index: int, new_index: int) -> str:
|
def change_name(name: str, old_index: int, new_index: int) -> str:
|
||||||
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
|
return name.replace(f".{old_index:d}.", f".{new_index:d}.")
|
||||||
|
|
||||||
|
|
||||||
def block_expansion(
|
def block_expansion(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
num_expand: int,
|
num_expand: int,
|
||||||
shard_size: Optional[str] = "2GB",
|
shard_size: str = "2GB",
|
||||||
save_safetensors: Optional[bool] = False,
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
||||||
@@ -76,27 +75,27 @@ def block_expansion(
|
|||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
if num_layers % num_expand != 0:
|
if num_layers % num_expand != 0:
|
||||||
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
|
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
|
||||||
|
|
||||||
split = num_layers // num_expand
|
split = num_layers // num_expand
|
||||||
layer_cnt = 0
|
layer_cnt = 0
|
||||||
output_state_dict = OrderedDict()
|
output_state_dict = OrderedDict()
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
if ".{:d}.".format(i) in key:
|
if f".{i:d}." in key:
|
||||||
output_state_dict[change_name(key, i, layer_cnt)] = value
|
output_state_dict[change_name(key, i, layer_cnt)] = value
|
||||||
|
|
||||||
print("Add layer {} copied from layer {}".format(layer_cnt, i))
|
print(f"Add layer {layer_cnt} copied from layer {i}")
|
||||||
layer_cnt += 1
|
layer_cnt += 1
|
||||||
if (i + 1) % split == 0:
|
if (i + 1) % split == 0:
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
if ".{:d}.".format(i) in key:
|
if f".{i:d}." in key:
|
||||||
if "down_proj" in key or "o_proj" in key:
|
if "down_proj" in key or "o_proj" in key:
|
||||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
||||||
else:
|
else:
|
||||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
||||||
|
|
||||||
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
|
print(f"Add layer {layer_cnt} expanded from layer {i}")
|
||||||
layer_cnt += 1
|
layer_cnt += 1
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
@@ -113,17 +112,17 @@ def block_expansion(
|
|||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
|
||||||
else:
|
else:
|
||||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
print("- Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print(f"model_name_or_path: {output_dir}")
|
||||||
print("finetuning_type: freeze")
|
print("finetuning_type: freeze")
|
||||||
print("freeze_trainable_layers: {}".format(num_expand))
|
print(f"freeze_trainable_layers: {num_expand}")
|
||||||
print("use_llama_pro: true")
|
print("use_llama_pro: true")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -16,7 +15,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
|||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
|
||||||
else:
|
else:
|
||||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
def save_config(input_dir: str, output_dir: str):
|
def save_config(input_dir: str, output_dir: str):
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||||
@@ -82,11 +81,14 @@ def save_config(input_dir: str, output_dir: str):
|
|||||||
|
|
||||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||||
json.dump(llama2_config_dict, f, indent=2)
|
json.dump(llama2_config_dict, f, indent=2)
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||||
|
|
||||||
|
|
||||||
def llamafy_baichuan2(
|
def llamafy_baichuan2(
|
||||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
input_dir: str,
|
||||||
|
output_dir: str,
|
||||||
|
shard_size: str = "2GB",
|
||||||
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -16,7 +15,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
@@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
|||||||
elif "lm_head" in key:
|
elif "lm_head" in key:
|
||||||
llama2_state_dict[key] = value
|
llama2_state_dict[key] = value
|
||||||
else:
|
else:
|
||||||
raise KeyError("Unable to process key {}".format(key))
|
raise KeyError(f"Unable to process key {key}")
|
||||||
|
|
||||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||||
@@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
|||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
|
||||||
else:
|
else:
|
||||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
return str(torch_dtype).replace("torch.", "")
|
return str(torch_dtype).replace("torch.", "")
|
||||||
|
|
||||||
|
|
||||||
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||||
@@ -135,11 +134,14 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
|||||||
|
|
||||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||||
json.dump(llama2_config_dict, f, indent=2)
|
json.dump(llama2_config_dict, f, indent=2)
|
||||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||||
|
|
||||||
|
|
||||||
def llamafy_qwen(
|
def llamafy_qwen(
|
||||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
input_dir: str,
|
||||||
|
output_dir: str,
|
||||||
|
shard_size: str = "2GB",
|
||||||
|
save_safetensors: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Converts the Qwen models in the same format as LLaMA2.
|
Converts the Qwen models in the same format as LLaMA2.
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is based on the HuggingFace's PEFT library.
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
@@ -36,15 +35,19 @@ def quantize_loftq(
|
|||||||
lora_alpha: int = None,
|
lora_alpha: int = None,
|
||||||
lora_rank: int = 16,
|
lora_rank: int = 16,
|
||||||
lora_dropout: float = 0,
|
lora_dropout: float = 0,
|
||||||
lora_target: str = "q_proj,v_proj",
|
lora_target: tuple = ("q_proj", "v_proj"),
|
||||||
save_safetensors: bool = True,
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||||
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||||
"""
|
"""
|
||||||
|
if isinstance(lora_target, str):
|
||||||
|
lora_target = [name.strip() for name in lora_target.split(",")]
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
|
|
||||||
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
@@ -52,7 +55,7 @@ def quantize_loftq(
|
|||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=lora_target,
|
||||||
init_lora_weights="loftq",
|
init_lora_weights="loftq",
|
||||||
loftq_config=loftq_config,
|
loftq_config=loftq_config,
|
||||||
)
|
)
|
||||||
@@ -63,22 +66,22 @@ def quantize_loftq(
|
|||||||
loftq_dir = os.path.join(output_dir, "loftq_init")
|
loftq_dir = os.path.join(output_dir, "loftq_init")
|
||||||
|
|
||||||
# Save LoftQ model
|
# Save LoftQ model
|
||||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
|
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||||
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||||
print("Adapter weights saved in {}".format(loftq_dir))
|
print(f"Adapter weights saved in {loftq_dir}")
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
base_model: "PreTrainedModel" = peft_model.unload()
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
print("- Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print(f"model_name_or_path: {output_dir}")
|
||||||
print("adapter_name_or_path: {}".format(loftq_dir))
|
print(f"adapter_name_or_path: {loftq_dir}")
|
||||||
print("finetuning_type: lora")
|
print("finetuning_type: lora")
|
||||||
print("quantization_bit: {}".format(loftq_bits))
|
print(f"quantization_bit: {loftq_bits}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is based on the HuggingFace's PEFT library.
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
@@ -31,26 +30,30 @@ if TYPE_CHECKING:
|
|||||||
def quantize_pissa(
|
def quantize_pissa(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
pissa_iter: int = 4,
|
pissa_iter: int = 16,
|
||||||
lora_alpha: int = None,
|
lora_alpha: int = None,
|
||||||
lora_rank: int = 16,
|
lora_rank: int = 16,
|
||||||
lora_dropout: float = 0,
|
lora_dropout: float = 0,
|
||||||
lora_target: str = "q_proj,v_proj",
|
lora_target: tuple = ("q_proj", "v_proj"),
|
||||||
save_safetensors: bool = True,
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
|
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
|
||||||
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
|
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||||
"""
|
"""
|
||||||
|
if isinstance(lora_target, str):
|
||||||
|
lora_target = [name.strip() for name in lora_target.split(",")]
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=lora_target,
|
||||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init PiSSA model
|
# Init PiSSA model
|
||||||
@@ -58,19 +61,20 @@ def quantize_pissa(
|
|||||||
pissa_dir = os.path.join(output_dir, "pissa_init")
|
pissa_dir = os.path.join(output_dir, "pissa_init")
|
||||||
|
|
||||||
# Save PiSSA model
|
# Save PiSSA model
|
||||||
|
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
||||||
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
||||||
print("Adapter weights saved in {}".format(pissa_dir))
|
print(f"Adapter weights saved in {pissa_dir}")
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
base_model: "PreTrainedModel" = peft_model.unload()
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print(f"Model weights saved in {output_dir}")
|
||||||
|
|
||||||
print("- Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print(f"model_name_or_path: {output_dir}")
|
||||||
print("adapter_name_or_path: {}".format(pissa_dir))
|
print(f"adapter_name_or_path: {pissa_dir}")
|
||||||
print("finetuning_type: lora")
|
print("finetuning_type: lora")
|
||||||
print("pissa_init: false")
|
print("pissa_init: false")
|
||||||
print("pissa_convert: true")
|
print("pissa_convert: true")
|
||||||
|
|||||||
65
scripts/test_image.py
Normal file
65
scripts/test_image.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
|
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="{}".format(os.environ.get("API_KEY", "0")),
|
||||||
|
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
|
||||||
|
)
|
||||||
|
messages = []
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Output the color and number of each box."},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = client.chat.completions.create(messages=messages, model="test")
|
||||||
|
messages.append(result.choices[0].message)
|
||||||
|
print("Round 1:", result.choices[0].message.content)
|
||||||
|
# The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What kind of flower is this?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = client.chat.completions.create(messages=messages, model="test")
|
||||||
|
messages.append(result.choices[0].message)
|
||||||
|
print("Round 2:", result.choices[0].message.content)
|
||||||
|
# The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|||||||
38
setup.py
38
setup.py
@@ -14,40 +14,54 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version() -> str:
|
||||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||||
(version,) = re.findall(pattern, file_content)
|
(version,) = re.findall(pattern, file_content)
|
||||||
return version
|
return version
|
||||||
|
|
||||||
|
|
||||||
def get_requires():
|
def get_requires() -> List[str]:
|
||||||
with open("requirements.txt", "r", encoding="utf-8") as f:
|
with open("requirements.txt", encoding="utf-8") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
def get_console_scripts() -> List[str]:
|
||||||
|
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
|
||||||
|
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
|
||||||
|
console_scripts.append("lmf = llamafactory.cli:main")
|
||||||
|
|
||||||
|
return console_scripts
|
||||||
|
|
||||||
|
|
||||||
extra_require = {
|
extra_require = {
|
||||||
"torch": ["torch>=1.13.1"],
|
"torch": ["torch>=1.13.1"],
|
||||||
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
||||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||||
"deepspeed": ["deepspeed>=0.10.0"],
|
"deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
|
||||||
|
"liger-kernel": ["liger-kernel"],
|
||||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||||
"vllm": ["vllm>=0.4.3"],
|
"hqq": ["hqq"],
|
||||||
"galore": ["galore-torch"],
|
"eetq": ["eetq"],
|
||||||
"badam": ["badam"],
|
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
||||||
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
|
||||||
"awq": ["autoawq"],
|
"awq": ["autoawq"],
|
||||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||||
|
"vllm": ["vllm>=0.4.3,<0.6.4"],
|
||||||
|
"galore": ["galore-torch"],
|
||||||
|
"badam": ["badam>=1.2.1"],
|
||||||
|
"adam-mini": ["adam-mini"],
|
||||||
"qwen": ["transformers_stream_generator"],
|
"qwen": ["transformers_stream_generator"],
|
||||||
"modelscope": ["modelscope"],
|
"modelscope": ["modelscope"],
|
||||||
"dev": ["ruff", "pytest"],
|
"openmind": ["openmind"],
|
||||||
|
"dev": ["pre-commit", "ruff", "pytest"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +72,7 @@ def main():
|
|||||||
author="hiyouga",
|
author="hiyouga",
|
||||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||||
description="Easy-to-use LLM fine-tuning framework",
|
description="Easy-to-use LLM fine-tuning framework",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||||
license="Apache 2.0 License",
|
license="Apache 2.0 License",
|
||||||
@@ -68,7 +82,7 @@ def main():
|
|||||||
python_requires=">=3.8.0",
|
python_requires=">=3.8.0",
|
||||||
install_requires=get_requires(),
|
install_requires=get_requires(),
|
||||||
extras_require=extra_require,
|
extras_require=extra_require,
|
||||||
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
|
entry_points={"console_scripts": get_console_scripts()},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
|
|||||||
def main():
|
def main():
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
app = create_app(chat_model)
|
app = create_app(chat_model)
|
||||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
api_host = os.getenv("API_HOST", "0.0.0.0")
|
||||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
api_port = int(os.getenv("API_PORT", "8000"))
|
||||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||||
uvicorn.run(app, host=api_host, port=api_port)
|
uvicorn.run(app, host=api_host, port=api_port)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,36 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
r"""
|
||||||
|
Efficient fine-tuning of large language models.
|
||||||
|
|
||||||
from .cli import VERSION
|
Level:
|
||||||
|
api, webui > chat, eval, train > data, model > hparams > extras
|
||||||
|
|
||||||
|
Dependency graph:
|
||||||
|
main:
|
||||||
|
transformers>=4.41.2,<=4.46.1
|
||||||
|
datasets>=2.16.0,<=3.1.0
|
||||||
|
accelerate>=0.34.0,<=1.0.1
|
||||||
|
peft>=0.11.1,<=0.12.0
|
||||||
|
trl>=0.8.6,<=0.9.6
|
||||||
|
attention:
|
||||||
|
transformers>=4.42.4 (gemma+fa2)
|
||||||
|
longlora:
|
||||||
|
transformers>=4.41.2,<=4.46.1
|
||||||
|
packing:
|
||||||
|
transformers>=4.41.2,<=4.46.1
|
||||||
|
|
||||||
|
Disable version checking: DISABLE_VERSION_CHECK=1
|
||||||
|
Enable VRAM recording: RECORD_VRAM=1
|
||||||
|
Force check imports: FORCE_CHECK_IMPORTS=1
|
||||||
|
Force using torchrun: FORCE_TORCHRUN=1
|
||||||
|
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
||||||
|
Use modelscope: USE_MODELSCOPE_HUB=1
|
||||||
|
Use openmind: USE_OPENMIND_HUB=1
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .extras.env import VERSION
|
||||||
|
|
||||||
|
|
||||||
__version__ = VERSION
|
__version__ = VERSION
|
||||||
|
|||||||
@@ -12,8 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
@@ -50,14 +52,24 @@ if is_uvicorn_available():
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
|
async def sweeper() -> None:
|
||||||
|
while True:
|
||||||
|
torch_gc()
|
||||||
|
await asyncio.sleep(300)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
|
||||||
|
if chat_model.engine_type == "huggingface":
|
||||||
|
asyncio.create_task(sweeper())
|
||||||
|
|
||||||
yield
|
yield
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
app = FastAPI(lifespan=lifespan)
|
root_path = os.getenv("FASTAPI_ROOT_PATH", "")
|
||||||
|
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=["*"],
|
||||||
@@ -65,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
api_key = os.environ.get("API_KEY")
|
api_key = os.getenv("API_KEY")
|
||||||
security = HTTPBearer(auto_error=False)
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||||
@@ -79,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
dependencies=[Depends(verify_api_key)],
|
dependencies=[Depends(verify_api_key)],
|
||||||
)
|
)
|
||||||
async def list_models():
|
async def list_models():
|
||||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
|
||||||
return ModelList(data=[model_card])
|
return ModelList(data=[model_card])
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
@@ -116,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
def run_api() -> None:
|
def run_api() -> None:
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
app = create_app(chat_model)
|
app = create_app(chat_model)
|
||||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
api_host = os.getenv("API_HOST", "0.0.0.0")
|
||||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
api_port = int(os.getenv("API_PORT", "8000"))
|
||||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||||
uvicorn.run(app, host=api_host, port=api_port)
|
uvicorn.run(app, host=api_host, port=api_port)
|
||||||
|
|||||||
@@ -16,11 +16,12 @@ import base64
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from ..data import Role as DataRole
|
from ..data import Role as DataRole
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||||
from .common import dictify, jsonify
|
from .common import dictify, jsonify
|
||||||
from .protocol import (
|
from .protocol import (
|
||||||
@@ -51,13 +52,12 @@ if is_requests_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
|
||||||
|
|
||||||
from ..chat import ChatModel
|
from ..chat import ChatModel
|
||||||
|
from ..data.mm_plugin import ImageInput
|
||||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
ROLE_MAPPING = {
|
ROLE_MAPPING = {
|
||||||
Role.USER: DataRole.USER.value,
|
Role.USER: DataRole.USER.value,
|
||||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||||
@@ -69,8 +69,8 @@ ROLE_MAPPING = {
|
|||||||
|
|
||||||
def _process_request(
|
def _process_request(
|
||||||
request: "ChatCompletionRequest",
|
request: "ChatCompletionRequest",
|
||||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
|
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
|
||||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||||
|
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||||
@@ -84,7 +84,7 @@ def _process_request(
|
|||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||||
|
|
||||||
input_messages = []
|
input_messages = []
|
||||||
image = None
|
images = []
|
||||||
for i, message in enumerate(request.messages):
|
for i, message in enumerate(request.messages):
|
||||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
@@ -93,7 +93,7 @@ def _process_request(
|
|||||||
|
|
||||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
{"name": tool_call.function.name, "argument": tool_call.function.arguments}
|
{"name": tool_call.function.name, "arguments": tool_call.function.arguments}
|
||||||
for tool_call in message.tool_calls
|
for tool_call in message.tool_calls
|
||||||
]
|
]
|
||||||
content = json.dumps(tool_calls, ensure_ascii=False)
|
content = json.dumps(tool_calls, ensure_ascii=False)
|
||||||
@@ -104,15 +104,14 @@ def _process_request(
|
|||||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
||||||
else:
|
else:
|
||||||
image_url = input_item.image_url.url
|
image_url = input_item.image_url.url
|
||||||
if image_url.startswith("data:image"): # base64 image
|
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
|
||||||
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
|
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
|
||||||
image_path = io.BytesIO(image_data)
|
|
||||||
elif os.path.isfile(image_url): # local file
|
elif os.path.isfile(image_url): # local file
|
||||||
image_path = open(image_url, "rb")
|
image_stream = open(image_url, "rb")
|
||||||
else: # web uri
|
else: # web uri
|
||||||
image_path = requests.get(image_url, stream=True).raw
|
image_stream = requests.get(image_url, stream=True).raw
|
||||||
|
|
||||||
image = Image.open(image_path).convert("RGB")
|
images.append(Image.open(image_stream).convert("RGB"))
|
||||||
else:
|
else:
|
||||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||||
|
|
||||||
@@ -125,7 +124,7 @@ def _process_request(
|
|||||||
else:
|
else:
|
||||||
tools = None
|
tools = None
|
||||||
|
|
||||||
return input_messages, system, tools, image
|
return input_messages, system, tools, images or None
|
||||||
|
|
||||||
|
|
||||||
def _create_stream_chat_completion_chunk(
|
def _create_stream_chat_completion_chunk(
|
||||||
@@ -143,13 +142,13 @@ def _create_stream_chat_completion_chunk(
|
|||||||
async def create_chat_completion_response(
|
async def create_chat_completion_response(
|
||||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||||
) -> "ChatCompletionResponse":
|
) -> "ChatCompletionResponse":
|
||||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
input_messages, system, tools, image = _process_request(request)
|
input_messages, system, tools, images = _process_request(request)
|
||||||
responses = await chat_model.achat(
|
responses = await chat_model.achat(
|
||||||
input_messages,
|
input_messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
@@ -170,7 +169,7 @@ async def create_chat_completion_response(
|
|||||||
tool_calls = []
|
tool_calls = []
|
||||||
for tool in result:
|
for tool in result:
|
||||||
function = Function(name=tool[0], arguments=tool[1])
|
function = Function(name=tool[0], arguments=tool[1])
|
||||||
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
|
||||||
|
|
||||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||||
finish_reason = Finish.TOOL
|
finish_reason = Finish.TOOL
|
||||||
@@ -194,8 +193,8 @@ async def create_chat_completion_response(
|
|||||||
async def create_stream_chat_completion_response(
|
async def create_stream_chat_completion_response(
|
||||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
input_messages, system, tools, image = _process_request(request)
|
input_messages, system, tools, images = _process_request(request)
|
||||||
if tools:
|
if tools:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||||
|
|
||||||
@@ -209,7 +208,7 @@ async def create_stream_chat_completion_response(
|
|||||||
input_messages,
|
input_messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
@@ -230,8 +229,9 @@ async def create_stream_chat_completion_response(
|
|||||||
async def create_score_evaluation_response(
|
async def create_score_evaluation_response(
|
||||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||||
) -> "ScoreEvaluationResponse":
|
) -> "ScoreEvaluationResponse":
|
||||||
|
score_id = f"scoreval-{uuid.uuid4().hex}"
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||||
|
|
||||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
messages: List[ChatMessage]
|
messages: List[ChatMessage]
|
||||||
tools: Optional[List[FunctionAvailable]] = None
|
tools: Optional[List[FunctionAvailable]] = None
|
||||||
do_sample: bool = True
|
do_sample: Optional[bool] = None
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
n: int = 1
|
n: int = 1
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
from vllm import AsyncLLMEngine
|
from vllm import AsyncLLMEngine
|
||||||
|
|
||||||
from ..data import Template
|
from ..data import Template
|
||||||
|
from ..data.mm_plugin import ImageInput, VideoInput
|
||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -35,6 +35,12 @@ class Response:
|
|||||||
|
|
||||||
|
|
||||||
class BaseEngine(ABC):
|
class BaseEngine(ABC):
|
||||||
|
r"""
|
||||||
|
Base class for inference engine of chat models.
|
||||||
|
|
||||||
|
Must implements async methods: chat(), stream_chat() and get_scores().
|
||||||
|
"""
|
||||||
|
|
||||||
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
||||||
tokenizer: "PreTrainedTokenizer"
|
tokenizer: "PreTrainedTokenizer"
|
||||||
can_generate: bool
|
can_generate: bool
|
||||||
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
|
|||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Initializes an inference engine.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
@@ -56,9 +66,14 @@ class BaseEngine(ABC):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]: ...
|
) -> List["Response"]:
|
||||||
|
r"""
|
||||||
|
Gets a list of responses of the chat model.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def stream_chat(
|
async def stream_chat(
|
||||||
@@ -66,13 +81,22 @@ class BaseEngine(ABC):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]: ...
|
) -> AsyncGenerator[str, None]:
|
||||||
|
r"""
|
||||||
|
Gets the response token-by-token of the chat model.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_scores(
|
async def get_scores(
|
||||||
self,
|
self,
|
||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List[float]: ...
|
) -> List[float]:
|
||||||
|
r"""
|
||||||
|
Gets a list of scores of the reward model.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||||
|
|
||||||
@@ -26,8 +27,7 @@ from .vllm_engine import VllmEngine
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
from ..data.mm_plugin import ImageInput, VideoInput
|
||||||
|
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
@@ -37,14 +37,23 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
|||||||
|
|
||||||
|
|
||||||
class ChatModel:
|
class ChatModel:
|
||||||
|
r"""
|
||||||
|
General class for chat models. Backed by huggingface or vllm engines.
|
||||||
|
|
||||||
|
Supports both sync and async methods.
|
||||||
|
Sync methods: chat(), stream_chat() and get_scores().
|
||||||
|
Async methods: achat(), astream_chat() and aget_scores().
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||||
|
self.engine_type = model_args.infer_backend
|
||||||
if model_args.infer_backend == "huggingface":
|
if model_args.infer_backend == "huggingface":
|
||||||
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
elif model_args.infer_backend == "vllm":
|
elif model_args.infer_backend == "vllm":
|
||||||
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||||
|
|
||||||
self._loop = asyncio.new_event_loop()
|
self._loop = asyncio.new_event_loop()
|
||||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
@@ -55,10 +64,16 @@ class ChatModel:
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
r"""
|
||||||
|
Gets a list of responses of the chat model.
|
||||||
|
"""
|
||||||
|
task = asyncio.run_coroutine_threadsafe(
|
||||||
|
self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
|
||||||
|
)
|
||||||
return task.result()
|
return task.result()
|
||||||
|
|
||||||
async def achat(
|
async def achat(
|
||||||
@@ -66,20 +81,28 @@ class ChatModel:
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
r"""
|
||||||
|
Asynchronously gets a list of responses of the chat model.
|
||||||
|
"""
|
||||||
|
return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
|
||||||
|
|
||||||
def stream_chat(
|
def stream_chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
r"""
|
||||||
|
Gets the response token-by-token of the chat model.
|
||||||
|
"""
|
||||||
|
generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||||
@@ -92,10 +115,14 @@ class ChatModel:
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
r"""
|
||||||
|
Asynchronously gets the response token-by-token of the chat model.
|
||||||
|
"""
|
||||||
|
async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
|
||||||
yield new_token
|
yield new_token
|
||||||
|
|
||||||
def get_scores(
|
def get_scores(
|
||||||
@@ -103,6 +130,9 @@ class ChatModel:
|
|||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List[float]:
|
) -> List[float]:
|
||||||
|
r"""
|
||||||
|
Gets a list of scores of the reward model.
|
||||||
|
"""
|
||||||
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
||||||
return task.result()
|
return task.result()
|
||||||
|
|
||||||
@@ -111,17 +141,18 @@ class ChatModel:
|
|||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List[float]:
|
) -> List[float]:
|
||||||
|
r"""
|
||||||
|
Asynchronously gets a list of scores of the reward model.
|
||||||
|
"""
|
||||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def run_chat() -> None:
|
def run_chat() -> None:
|
||||||
try:
|
if os.name != "nt":
|
||||||
import platform
|
try:
|
||||||
|
|
||||||
if platform.system() != "Windows":
|
|
||||||
import readline # noqa: F401
|
import readline # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Install `readline` for a better experience.")
|
print("Install `readline` for a better experience.")
|
||||||
|
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
messages = []
|
messages = []
|
||||||
|
|||||||
@@ -20,25 +20,26 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
|
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
from ..extras.misc import get_logits_processor
|
from ..extras.misc import get_logits_processor
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.image_processing_utils import BaseImageProcessor
|
|
||||||
from trl import PreTrainedModelWrapper
|
from trl import PreTrainedModelWrapper
|
||||||
|
|
||||||
from ..data import Template
|
from ..data import Template
|
||||||
|
from ..data.mm_plugin import ImageInput, VideoInput
|
||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceEngine(BaseEngine):
|
class HuggingfaceEngine(BaseEngine):
|
||||||
@@ -54,7 +55,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
self.tokenizer = tokenizer_module["tokenizer"]
|
self.tokenizer = tokenizer_module["tokenizer"]
|
||||||
self.processor = tokenizer_module["processor"]
|
self.processor = tokenizer_module["processor"]
|
||||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||||
self.model = load_model(
|
self.model = load_model(
|
||||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||||
) # must after fixing tokenizer to resize vocab
|
) # must after fixing tokenizer to resize vocab
|
||||||
@@ -62,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
try:
|
try:
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logger.warning("There is no current event loop, creating a new one.")
|
logger.warning_once("There is no current event loop, creating a new one.")
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_args(
|
def _process_args(
|
||||||
@@ -78,31 +79,30 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
if (
|
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||||
processor is not None
|
if images is not None:
|
||||||
and image is not None
|
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||||
and not hasattr(processor, "image_seq_length")
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
and template.image_token not in messages[0]["content"]
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
): # llava-like models
|
|
||||||
messages[0]["content"] = template.image_token + messages[0]["content"]
|
|
||||||
|
|
||||||
|
if videos is not None:
|
||||||
|
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||||
|
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
|
||||||
|
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||||
|
|
||||||
|
messages = template.mm_plugin.process_messages(
|
||||||
|
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
||||||
|
)
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or generating_args["default_system"]
|
system = system or generating_args["default_system"]
|
||||||
pixel_values = None
|
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
||||||
prompt_ids, _ = template.encode_oneturn(
|
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
||||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
|
||||||
)
|
)
|
||||||
if processor is not None and image is not None: # add image features
|
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
|
||||||
batch_feature = image_processor(image, return_tensors="pt")
|
|
||||||
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
|
||||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
|
||||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
|
||||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
|
||||||
|
|
||||||
prompt_length = len(prompt_ids)
|
prompt_length = len(prompt_ids)
|
||||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||||
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
||||||
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||||
|
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
|
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||||
|
|
||||||
generating_args = generating_args.copy()
|
generating_args = generating_args.copy()
|
||||||
generating_args.update(
|
generating_args.update(
|
||||||
@@ -164,8 +164,14 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
logits_processor=get_logits_processor(),
|
logits_processor=get_logits_processor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if pixel_values is not None:
|
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
|
||||||
gen_kwargs["pixel_values"] = pixel_values
|
for key, value in mm_inputs.items():
|
||||||
|
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
||||||
|
value = torch.stack(value) # assume they have same sizes
|
||||||
|
elif not isinstance(value, torch.Tensor):
|
||||||
|
value = torch.tensor(value)
|
||||||
|
|
||||||
|
gen_kwargs[key] = value.to(model.device)
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
return gen_kwargs, prompt_length
|
||||||
|
|
||||||
@@ -180,11 +186,22 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
template,
|
||||||
|
generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
images,
|
||||||
|
videos,
|
||||||
|
input_kwargs,
|
||||||
)
|
)
|
||||||
generate_output = model.generate(**gen_kwargs)
|
generate_output = model.generate(**gen_kwargs)
|
||||||
response_ids = generate_output[:, prompt_length:]
|
response_ids = generate_output[:, prompt_length:]
|
||||||
@@ -215,11 +232,22 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> Callable[[], str]:
|
) -> Callable[[], str]:
|
||||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
template,
|
||||||
|
generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
images,
|
||||||
|
videos,
|
||||||
|
input_kwargs,
|
||||||
)
|
)
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
gen_kwargs["streamer"] = streamer
|
gen_kwargs["streamer"] = streamer
|
||||||
@@ -242,37 +270,28 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> List[float]:
|
) -> List[float]:
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||||
device = getattr(model.pretrained_model, "device", "cuda")
|
device = getattr(model.pretrained_model, "device", "cuda")
|
||||||
inputs = tokenizer(
|
inputs: Dict[str, "torch.Tensor"] = tokenizer(
|
||||||
batch_input,
|
batch_input,
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
add_special_tokens=True,
|
add_special_tokens=False,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||||
input_ids: torch.Tensor = inputs["input_ids"]
|
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
|
||||||
|
|
||||||
if getattr(model.config, "model_type", None) == "chatglm":
|
|
||||||
values = torch.transpose(values, 0, 1)
|
|
||||||
|
|
||||||
scores = []
|
|
||||||
for i in range(input_ids.size(0)):
|
|
||||||
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
|
|
||||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
|
||||||
scores.append(values[i, end_index].nan_to_num().item())
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
@override
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@@ -288,19 +307,22 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
|
videos,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
return await loop.run_in_executor(pool, self._chat, *input_args)
|
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||||
|
|
||||||
|
@override
|
||||||
async def stream_chat(
|
async def stream_chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@@ -316,7 +338,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
|
videos,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
@@ -328,6 +351,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@override
|
||||||
async def get_scores(
|
async def get_scores(
|
||||||
self,
|
self,
|
||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
|
|||||||
@@ -13,35 +13,37 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
|
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||||
from ..extras.misc import get_device_count
|
from ..extras.misc import get_device_count
|
||||||
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
|
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||||
from ..model import load_config, load_tokenizer
|
from ..model import load_config, load_tokenizer
|
||||||
|
from ..model.model_utils.quantization import QuantizationMethod
|
||||||
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
|
if is_pillow_available():
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import Image as ImageObject
|
||||||
|
|
||||||
|
|
||||||
if is_vllm_available():
|
if is_vllm_available():
|
||||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
if is_vllm_version_greater_than_0_5():
|
|
||||||
from vllm.multimodal.image import ImagePixelData
|
|
||||||
else:
|
|
||||||
from vllm.sequence import MultiModalData
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
from ..data.mm_plugin import ImageInput, VideoInput
|
||||||
from transformers.image_processing_utils import BaseImageProcessor
|
|
||||||
|
|
||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VllmEngine(BaseEngine):
|
class VllmEngine(BaseEngine):
|
||||||
@@ -53,13 +55,18 @@ class VllmEngine(BaseEngine):
|
|||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
) -> None:
|
) -> None:
|
||||||
config = load_config(model_args) # may download model from ms hub
|
config = load_config(model_args) # may download model from ms hub
|
||||||
|
if getattr(config, "quantization_config", None): # gptq models should use float16
|
||||||
|
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||||
|
quant_method = quantization_config.get("quant_method", "")
|
||||||
|
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
|
||||||
|
model_args.infer_dtype = "float16"
|
||||||
|
|
||||||
self.can_generate = finetuning_args.stage == "sft"
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
self.tokenizer = tokenizer_module["tokenizer"]
|
self.tokenizer = tokenizer_module["tokenizer"]
|
||||||
self.processor = tokenizer_module["processor"]
|
self.processor = tokenizer_module["processor"]
|
||||||
self.tokenizer.padding_side = "left"
|
self.tokenizer.padding_side = "left"
|
||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||||
self.generating_args = generating_args.to_dict()
|
self.generating_args = generating_args.to_dict()
|
||||||
|
|
||||||
engine_args = {
|
engine_args = {
|
||||||
@@ -76,20 +83,14 @@ class VllmEngine(BaseEngine):
|
|||||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||||
"max_lora_rank": model_args.vllm_max_lora_rank,
|
"max_lora_rank": model_args.vllm_max_lora_rank,
|
||||||
}
|
}
|
||||||
|
if isinstance(model_args.vllm_config, dict):
|
||||||
|
engine_args.update(model_args.vllm_config)
|
||||||
|
|
||||||
if model_args.visual_inputs:
|
if getattr(config, "is_yi_vl_derived_model", None):
|
||||||
image_size = config.vision_config.image_size
|
import vllm.model_executor.models.llava
|
||||||
patch_size = config.vision_config.patch_size
|
|
||||||
self.image_feature_size = (image_size // patch_size) ** 2
|
|
||||||
engine_args["image_input_type"] = "pixel_values"
|
|
||||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
|
|
||||||
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
|
||||||
engine_args["image_feature_size"] = self.image_feature_size
|
|
||||||
if getattr(config, "is_yi_vl_derived_model", None):
|
|
||||||
import vllm.model_executor.models.llava
|
|
||||||
|
|
||||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||||
|
|
||||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
@@ -102,38 +103,28 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncIterator["RequestOutput"]:
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
|
if images is not None:
|
||||||
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
|
|
||||||
if (
|
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
|
||||||
self.processor is not None
|
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
|
||||||
and image is not None
|
|
||||||
and not hasattr(self.processor, "image_seq_length")
|
|
||||||
and self.template.image_token not in messages[0]["content"]
|
|
||||||
): # llava-like models (TODO: paligemma models)
|
|
||||||
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
|
|
||||||
|
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
|
||||||
system = system or self.generating_args["default_system"]
|
|
||||||
prompt_ids, _ = self.template.encode_oneturn(
|
|
||||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.processor is not None and image is not None: # add image features
|
|
||||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
|
||||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
|
||||||
if is_vllm_version_greater_than_0_5():
|
|
||||||
multi_modal_data = ImagePixelData(image=pixel_values)
|
|
||||||
else: # TODO: remove vllm 0.4.3 support
|
|
||||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
|
||||||
else:
|
else:
|
||||||
multi_modal_data = None
|
image_str = self.template.mm_plugin.image_token or ""
|
||||||
|
|
||||||
|
paired_messages = [
|
||||||
|
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
|
||||||
|
for message in messages
|
||||||
|
] + [{"role": "assistant", "content": ""}]
|
||||||
|
system = system or self.generating_args["default_system"]
|
||||||
|
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||||
prompt_length = len(prompt_ids)
|
prompt_length = len(prompt_ids)
|
||||||
|
|
||||||
use_beam_search: bool = self.generating_args["num_beams"] > 1
|
|
||||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||||
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
||||||
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
||||||
@@ -144,6 +135,9 @@ class VllmEngine(BaseEngine):
|
|||||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||||
|
|
||||||
|
if length_penalty is not None:
|
||||||
|
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
|
||||||
|
|
||||||
if "max_new_tokens" in self.generating_args:
|
if "max_new_tokens" in self.generating_args:
|
||||||
max_tokens = self.generating_args["max_new_tokens"]
|
max_tokens = self.generating_args["max_new_tokens"]
|
||||||
elif "max_length" in self.generating_args:
|
elif "max_length" in self.generating_args:
|
||||||
@@ -167,32 +161,47 @@ class VllmEngine(BaseEngine):
|
|||||||
temperature=temperature if temperature is not None else self.generating_args["temperature"],
|
temperature=temperature if temperature is not None else self.generating_args["temperature"],
|
||||||
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
||||||
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
||||||
use_beam_search=use_beam_search,
|
|
||||||
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
|
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if images is not None: # add image features
|
||||||
|
image_data = []
|
||||||
|
for image in images:
|
||||||
|
if not isinstance(image, (str, ImageObject)):
|
||||||
|
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||||
|
|
||||||
|
if isinstance(image, str):
|
||||||
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
|
image_data.append(image)
|
||||||
|
|
||||||
|
multi_modal_data = {"image": image_data}
|
||||||
|
else:
|
||||||
|
multi_modal_data = None
|
||||||
|
|
||||||
result_generator = self.model.generate(
|
result_generator = self.model.generate(
|
||||||
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=self.lora_request,
|
lora_request=self.lora_request,
|
||||||
)
|
)
|
||||||
return result_generator
|
return result_generator
|
||||||
|
|
||||||
|
@override
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
final_output = None
|
final_output = None
|
||||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||||
async for request_output in generator:
|
async for request_output in generator:
|
||||||
final_output = request_output
|
final_output = request_output
|
||||||
|
|
||||||
@@ -209,21 +218,24 @@ class VllmEngine(BaseEngine):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@override
|
||||||
async def stream_chat(
|
async def stream_chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["NDArray"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||||
async for result in generator:
|
async for result in generator:
|
||||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||||
generated_text = result.outputs[0].text
|
generated_text = result.outputs[0].text
|
||||||
yield delta_text
|
yield delta_text
|
||||||
|
|
||||||
|
@override
|
||||||
async def get_scores(
|
async def get_scores(
|
||||||
self,
|
self,
|
||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ from . import launcher
|
|||||||
from .api.app import run_api
|
from .api.app import run_api
|
||||||
from .chat.chat_model import run_chat
|
from .chat.chat_model import run_chat
|
||||||
from .eval.evaluator import run_eval
|
from .eval.evaluator import run_eval
|
||||||
|
from .extras import logging
|
||||||
from .extras.env import VERSION, print_env
|
from .extras.env import VERSION, print_env
|
||||||
from .extras.logging import get_logger
|
|
||||||
from .extras.misc import get_device_count
|
from .extras.misc import get_device_count
|
||||||
from .train.tuner import export_model, run_exp
|
from .train.tuner import export_model, run_exp
|
||||||
from .webui.interface import run_web_demo, run_web_ui
|
from .webui.interface import run_web_demo, run_web_ui
|
||||||
@@ -47,7 +47,7 @@ USAGE = (
|
|||||||
WELCOME = (
|
WELCOME = (
|
||||||
"-" * 58
|
"-" * 58
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||||
+ " " * (21 - len(VERSION))
|
+ " " * (21 - len(VERSION))
|
||||||
+ "|\n|"
|
+ "|\n|"
|
||||||
+ " " * 56
|
+ " " * 56
|
||||||
@@ -56,7 +56,7 @@ WELCOME = (
|
|||||||
+ "-" * 58
|
+ "-" * 58
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
@@ -74,7 +74,7 @@ class Command(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
command = sys.argv.pop(1)
|
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
|
||||||
if command == Command.API:
|
if command == Command.API:
|
||||||
run_api()
|
run_api()
|
||||||
elif command == Command.CHAT:
|
elif command == Command.CHAT:
|
||||||
@@ -86,26 +86,28 @@ def main():
|
|||||||
elif command == Command.EXPORT:
|
elif command == Command.EXPORT:
|
||||||
export_model()
|
export_model()
|
||||||
elif command == Command.TRAIN:
|
elif command == Command.TRAIN:
|
||||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||||
if force_torchrun or get_device_count() > 1:
|
if force_torchrun or get_device_count() > 1:
|
||||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||||
subprocess.run(
|
process = subprocess.run(
|
||||||
(
|
(
|
||||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||||
).format(
|
)
|
||||||
nnodes=os.environ.get("NNODES", "1"),
|
.format(
|
||||||
node_rank=os.environ.get("RANK", "0"),
|
nnodes=os.getenv("NNODES", "1"),
|
||||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
node_rank=os.getenv("NODE_RANK", "0"),
|
||||||
|
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
file_name=launcher.__file__,
|
file_name=launcher.__file__,
|
||||||
args=" ".join(sys.argv[1:]),
|
args=" ".join(sys.argv[1:]),
|
||||||
),
|
)
|
||||||
shell=True,
|
.split()
|
||||||
)
|
)
|
||||||
|
sys.exit(process.returncode)
|
||||||
else:
|
else:
|
||||||
run_exp()
|
run_exp()
|
||||||
elif command == Command.WEBDEMO:
|
elif command == Command.WEBDEMO:
|
||||||
@@ -117,4 +119,4 @@ def main():
|
|||||||
elif command == Command.HELP:
|
elif command == Command.HELP:
|
||||||
print(USAGE)
|
print(USAGE)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown command: {}".format(command))
|
raise NotImplementedError(f"Unknown command: {command}.")
|
||||||
|
|||||||
@@ -12,7 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
from .collator import (
|
||||||
|
KTODataCollatorWithPadding,
|
||||||
|
MultiModalDataCollatorForSeq2Seq,
|
||||||
|
PairwiseDataCollatorWithPadding,
|
||||||
|
SFTDataCollatorWith4DAttentionMask,
|
||||||
|
)
|
||||||
from .data_utils import Role, split_dataset
|
from .data_utils import Role, split_dataset
|
||||||
from .loader import get_dataset
|
from .loader import get_dataset
|
||||||
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||||
@@ -20,7 +25,9 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"KTODataCollatorWithPadding",
|
"KTODataCollatorWithPadding",
|
||||||
|
"MultiModalDataCollatorForSeq2Seq",
|
||||||
"PairwiseDataCollatorWithPadding",
|
"PairwiseDataCollatorWithPadding",
|
||||||
|
"SFTDataCollatorWith4DAttentionMask",
|
||||||
"Role",
|
"Role",
|
||||||
"split_dataset",
|
"split_dataset",
|
||||||
"get_dataset",
|
"get_dataset",
|
||||||
|
|||||||
@@ -14,11 +14,9 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from datasets import Features
|
from ..extras import logging
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from .data_utils import Role
|
from .data_utils import Role
|
||||||
|
|
||||||
|
|
||||||
@@ -27,88 +25,123 @@ if TYPE_CHECKING:
|
|||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
from .mm_plugin import ImageInput, VideoInput
|
||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
def _convert_images(
|
||||||
|
images: Union["ImageInput", Sequence["ImageInput"]],
|
||||||
|
dataset_attr: "DatasetAttr",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Optional[List["ImageInput"]]:
|
||||||
r"""
|
r"""
|
||||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||||
"""
|
"""
|
||||||
outputs = []
|
if not isinstance(images, list):
|
||||||
if dataset_attr.load_from in ["script", "file"]:
|
images = [images]
|
||||||
for image in images:
|
elif len(images) == 0:
|
||||||
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
|
return None
|
||||||
outputs.append(os.path.join(data_args.dataset_dir, image))
|
else:
|
||||||
else:
|
images = images[:]
|
||||||
outputs.append(image)
|
|
||||||
|
|
||||||
return outputs
|
if dataset_attr.load_from in ["script", "file"]:
|
||||||
|
for i in range(len(images)):
|
||||||
|
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
|
||||||
|
images[i] = os.path.join(data_args.image_dir, images[i])
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_videos(
|
||||||
|
videos: Union["VideoInput", Sequence["VideoInput"]],
|
||||||
|
dataset_attr: "DatasetAttr",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Optional[List["VideoInput"]]:
|
||||||
|
r"""
|
||||||
|
Optionally concatenates video path to dataset dir when loading from local disk.
|
||||||
|
"""
|
||||||
|
if not isinstance(videos, list):
|
||||||
|
videos = [videos]
|
||||||
|
elif len(videos) == 0:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
videos = videos[:]
|
||||||
|
|
||||||
|
if dataset_attr.load_from in ["script", "file"]:
|
||||||
|
for i in range(len(videos)):
|
||||||
|
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
|
||||||
|
videos[i] = os.path.join(data_args.image_dir, videos[i])
|
||||||
|
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
def convert_alpaca(
|
def convert_alpaca(
|
||||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
example: Dict[str, Any],
|
||||||
) -> Dict[str, List[Any]]:
|
dataset_attr: "DatasetAttr",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Dict[str, Any]:
|
||||||
r"""
|
r"""
|
||||||
Converts alpaca format dataset to the standard format.
|
Converts alpaca format dataset to the standard format.
|
||||||
"""
|
"""
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
prompt = []
|
||||||
|
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
|
||||||
|
for old_prompt, old_response in example[dataset_attr.history]:
|
||||||
|
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
||||||
|
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
||||||
|
|
||||||
|
query = []
|
||||||
|
if dataset_attr.prompt and example[dataset_attr.prompt]:
|
||||||
|
query.append(example[dataset_attr.prompt])
|
||||||
|
|
||||||
|
if dataset_attr.query and example[dataset_attr.query]:
|
||||||
|
query.append(example[dataset_attr.query])
|
||||||
|
|
||||||
|
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
|
||||||
|
|
||||||
|
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||||
|
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
|
||||||
|
if example[dataset_attr.kto_tag]:
|
||||||
|
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||||
|
else:
|
||||||
|
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||||
|
elif (
|
||||||
|
dataset_attr.ranking
|
||||||
|
and isinstance(example[dataset_attr.chosen], str)
|
||||||
|
and isinstance(example[dataset_attr.rejected], str)
|
||||||
|
): # pairwise example
|
||||||
|
response = [
|
||||||
|
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
|
||||||
|
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
|
||||||
|
]
|
||||||
|
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
|
||||||
|
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
|
||||||
|
else: # unsupervised
|
||||||
|
response = []
|
||||||
|
|
||||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
for i in range(len(examples[dataset_attr.prompt])):
|
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
prompt = []
|
output = {
|
||||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
"_prompt": prompt,
|
||||||
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
"_response": response,
|
||||||
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
"_system": example[dataset_attr.system] if dataset_attr.system else "",
|
||||||
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
|
||||||
|
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||||
content = []
|
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||||
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
}
|
||||||
content.append(examples[dataset_attr.prompt][i])
|
return output
|
||||||
|
|
||||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
|
||||||
content.append(examples[dataset_attr.query][i])
|
|
||||||
|
|
||||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
|
||||||
|
|
||||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
|
||||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
|
||||||
if examples[dataset_attr.kto_tag][i]:
|
|
||||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
|
||||||
else:
|
|
||||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
|
||||||
elif (
|
|
||||||
dataset_attr.ranking
|
|
||||||
and isinstance(examples[dataset_attr.chosen][i], str)
|
|
||||||
and isinstance(examples[dataset_attr.rejected][i], str)
|
|
||||||
): # pairwise example
|
|
||||||
response = [
|
|
||||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
|
||||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
|
||||||
]
|
|
||||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
|
||||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
|
||||||
else: # unsupervised
|
|
||||||
response = []
|
|
||||||
|
|
||||||
outputs["prompt"].append(prompt)
|
|
||||||
outputs["response"].append(response)
|
|
||||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
|
||||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
|
||||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def convert_sharegpt(
|
def convert_sharegpt(
|
||||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
example: Dict[str, Any],
|
||||||
) -> Dict[str, List[Any]]:
|
dataset_attr: "DatasetAttr",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Dict[str, Any]:
|
||||||
r"""
|
r"""
|
||||||
Converts sharegpt format dataset to the standard format.
|
Converts sharegpt format dataset to the standard format.
|
||||||
"""
|
"""
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
|
||||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
|
||||||
tag_mapping = {
|
tag_mapping = {
|
||||||
dataset_attr.user_tag: Role.USER.value,
|
dataset_attr.user_tag: Role.USER.value,
|
||||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||||
@@ -119,74 +152,79 @@ def convert_sharegpt(
|
|||||||
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
||||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||||
accept_tags = (odd_tags, even_tags)
|
accept_tags = (odd_tags, even_tags)
|
||||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
messages = example[dataset_attr.messages]
|
||||||
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
if (
|
||||||
system = messages[0][dataset_attr.content_tag]
|
dataset_attr.system_tag
|
||||||
messages = messages[1:]
|
and len(messages) != 0
|
||||||
else:
|
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
|
||||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
):
|
||||||
|
system = messages[0][dataset_attr.content_tag]
|
||||||
|
messages = messages[1:]
|
||||||
|
else:
|
||||||
|
system = example[dataset_attr.system] if dataset_attr.system else ""
|
||||||
|
|
||||||
if len(messages) == 0:
|
aligned_messages = []
|
||||||
continue
|
broken_data = False
|
||||||
|
for turn_idx, message in enumerate(messages):
|
||||||
aligned_messages = []
|
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||||
broken_data = False
|
logger.warning_rank0(f"Invalid role tag in {messages}.")
|
||||||
for turn_idx, message in enumerate(messages):
|
|
||||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
|
||||||
logger.warning("Invalid role tag in {}.".format(messages))
|
|
||||||
broken_data = True
|
|
||||||
|
|
||||||
aligned_messages.append(
|
|
||||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
|
||||||
)
|
|
||||||
|
|
||||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
|
||||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
|
||||||
):
|
|
||||||
logger.warning("Invalid message count in {}.".format(messages))
|
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
aligned_messages.append(
|
||||||
prompt = aligned_messages[:-1]
|
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||||
response = aligned_messages[-1:]
|
)
|
||||||
if examples[dataset_attr.kto_tag][i]:
|
|
||||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
|
||||||
else:
|
|
||||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
|
||||||
elif (
|
|
||||||
dataset_attr.ranking
|
|
||||||
and isinstance(examples[dataset_attr.chosen][i], dict)
|
|
||||||
and isinstance(examples[dataset_attr.rejected][i], dict)
|
|
||||||
): # pairwise example
|
|
||||||
chosen = examples[dataset_attr.chosen][i]
|
|
||||||
rejected = examples[dataset_attr.rejected][i]
|
|
||||||
if (
|
|
||||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
|
||||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
|
||||||
):
|
|
||||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
|
||||||
broken_data = True
|
|
||||||
|
|
||||||
prompt = aligned_messages
|
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||||
response = [
|
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||||
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
):
|
||||||
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
logger.warning_rank0(f"Invalid message count in {messages}.")
|
||||||
]
|
broken_data = True
|
||||||
else: # normal example
|
|
||||||
prompt = aligned_messages[:-1]
|
|
||||||
response = aligned_messages[-1:]
|
|
||||||
|
|
||||||
if broken_data:
|
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||||
logger.warning("Skipping this abnormal example.")
|
prompt = aligned_messages[:-1]
|
||||||
continue
|
response = aligned_messages[-1:]
|
||||||
|
if example[dataset_attr.kto_tag]:
|
||||||
|
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||||
|
else:
|
||||||
|
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||||
|
elif (
|
||||||
|
dataset_attr.ranking
|
||||||
|
and isinstance(example[dataset_attr.chosen], dict)
|
||||||
|
and isinstance(example[dataset_attr.rejected], dict)
|
||||||
|
): # pairwise example
|
||||||
|
chosen = example[dataset_attr.chosen]
|
||||||
|
rejected = example[dataset_attr.rejected]
|
||||||
|
if (
|
||||||
|
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||||
|
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||||
|
):
|
||||||
|
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
|
||||||
|
broken_data = True
|
||||||
|
|
||||||
outputs["prompt"].append(prompt)
|
prompt = aligned_messages
|
||||||
outputs["response"].append(response)
|
response = [
|
||||||
outputs["system"].append(system)
|
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
||||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
||||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
]
|
||||||
|
else: # normal example
|
||||||
|
prompt = aligned_messages[:-1]
|
||||||
|
response = aligned_messages[-1:]
|
||||||
|
|
||||||
return outputs
|
if broken_data:
|
||||||
|
logger.warning_rank0("Skipping this abnormal example.")
|
||||||
|
prompt, response = [], []
|
||||||
|
|
||||||
|
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
|
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
|
output = {
|
||||||
|
"_prompt": prompt,
|
||||||
|
"_response": response,
|
||||||
|
"_system": system,
|
||||||
|
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
|
||||||
|
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||||
|
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def align_dataset(
|
def align_dataset(
|
||||||
@@ -197,11 +235,12 @@ def align_dataset(
|
|||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
r"""
|
r"""
|
||||||
Aligned dataset:
|
Aligned dataset:
|
||||||
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||||
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||||
system: "..."
|
_system: "..."
|
||||||
tools: "...",
|
_tools: "...",
|
||||||
images: [],
|
_images: [],
|
||||||
|
_videos: [],
|
||||||
"""
|
"""
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
@@ -209,19 +248,6 @@ def align_dataset(
|
|||||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
|
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
|
|
||||||
column_names = list(next(iter(dataset)).keys())
|
column_names = list(next(iter(dataset)).keys())
|
||||||
features = Features.from_dict(
|
|
||||||
{
|
|
||||||
"prompt": [
|
|
||||||
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
|
||||||
],
|
|
||||||
"response": [
|
|
||||||
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
|
||||||
],
|
|
||||||
"system": {"dtype": "string", "_type": "Value"},
|
|
||||||
"tools": {"dtype": "string", "_type": "Value"},
|
|
||||||
"images": [{"_type": "Image"}],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
@@ -232,8 +258,7 @@ def align_dataset(
|
|||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
convert_func,
|
convert_func,
|
||||||
batched=True,
|
batched=False,
|
||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
features=features,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the OpenAccess AI Collective's axolotl library.
|
||||||
|
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -13,19 +16,120 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Sequence
|
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import ProcessorMixin
|
||||||
|
|
||||||
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||||
|
r"""
|
||||||
|
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||||
|
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
||||||
|
|
||||||
|
e.g.
|
||||||
|
```python
|
||||||
|
# input
|
||||||
|
[[1, 1, 2, 2, 2, 0]]
|
||||||
|
# output
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[o, x, x, x, x, x],
|
||||||
|
[o, o, x, x, x, x],
|
||||||
|
[x, x, o, x, x, x],
|
||||||
|
[x, x, o, o, x, x],
|
||||||
|
[x, x, o, o, o, x],
|
||||||
|
[x, x, x, x, x, x],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
```
|
||||||
|
where `o` equals to `0.0`, `x` equals to `min_dtype`.
|
||||||
|
"""
|
||||||
|
bsz, seq_len = attention_mask_with_indices.size()
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
|
||||||
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
|
padding_mask = torch.where(expanded_mask != 0, 1, 0)
|
||||||
|
# Create a block-diagonal mask.
|
||||||
|
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
|
||||||
|
# Use the lower triangular mask to zero out the upper triangular part
|
||||||
|
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
|
||||||
|
# Invert the attention mask.
|
||||||
|
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
|
||||||
|
return attention_mask_4d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
r"""
|
||||||
|
Data collator that supports VLMs.
|
||||||
|
|
||||||
|
Features should contain input_ids, attention_mask, labels and images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
template: Optional["Template"] = None
|
||||||
|
processor: Optional["ProcessorMixin"] = None
|
||||||
|
|
||||||
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||||
|
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
|
||||||
|
for feature in features:
|
||||||
|
images = feature.pop("images", None) or []
|
||||||
|
videos = feature.pop("videos", None) or []
|
||||||
|
batch_images.extend(images)
|
||||||
|
batch_videos.extend(videos)
|
||||||
|
batch_imglens.append(len(images))
|
||||||
|
batch_vidlens.append(len(videos))
|
||||||
|
batch_input_ids.append(feature["input_ids"])
|
||||||
|
|
||||||
|
mm_inputs = self.template.mm_plugin.get_mm_inputs(
|
||||||
|
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
|
||||||
|
)
|
||||||
|
if "token_type_ids" in mm_inputs:
|
||||||
|
token_type_ids = mm_inputs.pop("token_type_ids")
|
||||||
|
for i, feature in enumerate(features):
|
||||||
|
feature["token_type_ids"] = token_type_ids[i]
|
||||||
|
|
||||||
|
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||||
|
features.update(mm_inputs)
|
||||||
|
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
||||||
|
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||||
|
r"""
|
||||||
|
Data collator for 4d attention mask.
|
||||||
|
"""
|
||||||
|
|
||||||
|
block_diag_attn: bool = False
|
||||||
|
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||||
|
compute_dtype: "torch.dtype" = torch.float32
|
||||||
|
|
||||||
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||||
|
features = super().__call__(features)
|
||||||
|
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||||
|
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for pairwise data.
|
Data collator for pairwise data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||||
r"""
|
r"""
|
||||||
Pads batched data to the longest sequence in the batch.
|
Pads batched data to the longest sequence in the batch.
|
||||||
|
|
||||||
@@ -36,28 +140,24 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
for key in ("chosen", "rejected"):
|
for key in ("chosen", "rejected"):
|
||||||
for feature in features:
|
for feature in features:
|
||||||
target_feature = {
|
target_feature = {
|
||||||
"input_ids": feature["{}_input_ids".format(key)],
|
"input_ids": feature[f"{key}_input_ids"],
|
||||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
"attention_mask": feature[f"{key}_attention_mask"],
|
||||||
"labels": feature["{}_labels".format(key)],
|
"labels": feature[f"{key}_labels"],
|
||||||
|
"images": feature["images"],
|
||||||
|
"videos": feature["videos"],
|
||||||
}
|
}
|
||||||
if "pixel_values" in feature:
|
|
||||||
target_feature["pixel_values"] = feature["pixel_values"]
|
|
||||||
|
|
||||||
if "{}_token_type_ids".format(key) in feature:
|
|
||||||
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
|
||||||
|
|
||||||
concatenated_features.append(target_feature)
|
concatenated_features.append(target_feature)
|
||||||
|
|
||||||
return super().__call__(concatenated_features)
|
return super().__call__(concatenated_features)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for KTO data.
|
Data collator for KTO data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||||
target_features = []
|
target_features = []
|
||||||
kl_features = []
|
kl_features = []
|
||||||
kto_tags = []
|
kto_tags = []
|
||||||
@@ -66,19 +166,16 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
"input_ids": feature["input_ids"],
|
"input_ids": feature["input_ids"],
|
||||||
"attention_mask": feature["attention_mask"],
|
"attention_mask": feature["attention_mask"],
|
||||||
"labels": feature["labels"],
|
"labels": feature["labels"],
|
||||||
|
"images": feature["images"],
|
||||||
|
"videos": feature["videos"],
|
||||||
}
|
}
|
||||||
kl_feature = {
|
kl_feature = {
|
||||||
"input_ids": feature["kl_input_ids"],
|
"input_ids": feature["kl_input_ids"],
|
||||||
"attention_mask": feature["kl_attention_mask"],
|
"attention_mask": feature["kl_attention_mask"],
|
||||||
"labels": feature["kl_labels"],
|
"labels": feature["kl_labels"],
|
||||||
|
"images": feature["images"],
|
||||||
|
"videos": feature["videos"],
|
||||||
}
|
}
|
||||||
if "pixel_values" in feature:
|
|
||||||
target_feature["pixel_values"] = feature["pixel_values"]
|
|
||||||
|
|
||||||
if "token_type_ids" in feature:
|
|
||||||
target_feature["token_type_ids"] = feature["token_type_ids"]
|
|
||||||
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
|
||||||
|
|
||||||
target_features.append(target_feature)
|
target_features.append(target_feature)
|
||||||
kl_features.append(kl_feature)
|
kl_features.append(kl_feature)
|
||||||
kto_tags.append(feature["kto_tags"])
|
kto_tags.append(feature["kto_tags"])
|
||||||
@@ -88,7 +185,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||||
batch["kl_labels"] = kl_batch["labels"]
|
batch["kl_labels"] = kl_batch["labels"]
|
||||||
if "token_type_ids" in batch:
|
if "token_type_ids" in kl_batch:
|
||||||
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
||||||
|
|
||||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||||
|
|||||||
@@ -13,21 +13,23 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets
|
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import Seq2SeqTrainingArguments
|
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
@@ -39,54 +41,52 @@ class Role(str, Enum):
|
|||||||
OBSERVATION = "observation"
|
OBSERVATION = "observation"
|
||||||
|
|
||||||
|
|
||||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
class DatasetModule(TypedDict):
|
||||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||||
max_target_len = max(max_target_len, reserved_label_len)
|
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||||
max_source_len = max_len - min(max_target_len, target_len)
|
|
||||||
return max_source_len, max_target_len
|
|
||||||
|
|
||||||
|
|
||||||
def merge_dataset(
|
def merge_dataset(
|
||||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||||
data_args: "DataArguments",
|
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
|
r"""
|
||||||
|
Merges multiple datasets to a unified dataset.
|
||||||
|
"""
|
||||||
if len(all_datasets) == 1:
|
if len(all_datasets) == 1:
|
||||||
return all_datasets[0]
|
return all_datasets[0]
|
||||||
elif data_args.mix_strategy == "concat":
|
elif data_args.mix_strategy == "concat":
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||||
|
|
||||||
return concatenate_datasets(all_datasets)
|
return concatenate_datasets(all_datasets)
|
||||||
elif data_args.mix_strategy.startswith("interleave"):
|
elif data_args.mix_strategy.startswith("interleave"):
|
||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||||
|
|
||||||
return interleave_datasets(
|
return interleave_datasets(
|
||||||
datasets=all_datasets,
|
datasets=all_datasets,
|
||||||
probabilities=data_args.interleave_probs,
|
probabilities=data_args.interleave_probs,
|
||||||
seed=training_args.seed,
|
seed=seed,
|
||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy.")
|
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
|
||||||
) -> Dict[str, "Dataset"]:
|
) -> "DatasetDict":
|
||||||
if training_args.do_train:
|
r"""
|
||||||
if data_args.val_size > 1e-6: # Split the dataset
|
Splits the dataset and returns a dataset dict containing train set and validation set.
|
||||||
if data_args.streaming:
|
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
Supports both map dataset and iterable dataset.
|
||||||
val_set = dataset.take(int(data_args.val_size))
|
"""
|
||||||
train_set = dataset.skip(int(data_args.val_size))
|
if data_args.streaming:
|
||||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
|
||||||
else:
|
val_set = dataset.take(int(data_args.val_size))
|
||||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
train_set = dataset.skip(int(data_args.val_size))
|
||||||
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
return DatasetDict({"train": train_set, "validation": val_set})
|
||||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
else:
|
||||||
else:
|
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||||
if data_args.streaming:
|
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
|
||||||
return {"train_dataset": dataset}
|
|
||||||
else: # do_eval or do_predict
|
|
||||||
return {"eval_dataset": dataset}
|
|
||||||
|
|||||||
@@ -16,108 +16,36 @@ import json
|
|||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from .data_utils import SLOTS
|
||||||
|
from .tool_utils import get_tool_utils
|
||||||
|
|
||||||
|
|
||||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
if TYPE_CHECKING:
|
||||||
|
from .tool_utils import FunctionCall
|
||||||
|
|
||||||
DEFAULT_TOOL_PROMPT = (
|
|
||||||
"You have access to the following tools:\n{tool_text}"
|
|
||||||
"Use the following format if using a tool:\n"
|
|
||||||
"```\n"
|
|
||||||
"Action: tool name (one of [{tool_names}]).\n"
|
|
||||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
|
||||||
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
|
||||||
"```\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
GLM4_TOOL_PROMPT = (
|
|
||||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
|
||||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
|
||||||
tool_text = ""
|
|
||||||
tool_names = []
|
|
||||||
for tool in tools:
|
|
||||||
param_text = ""
|
|
||||||
for name, param in tool["parameters"]["properties"].items():
|
|
||||||
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
|
||||||
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
|
||||||
items = (
|
|
||||||
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
|
|
||||||
)
|
|
||||||
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
|
||||||
name=name,
|
|
||||||
type=param.get("type", ""),
|
|
||||||
required=required,
|
|
||||||
desc=param.get("description", ""),
|
|
||||||
enum=enum,
|
|
||||||
items=items,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
|
||||||
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
|
||||||
)
|
|
||||||
tool_names.append(tool["name"])
|
|
||||||
|
|
||||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
|
||||||
|
|
||||||
|
|
||||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
|
||||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
|
||||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
|
||||||
if not action_match:
|
|
||||||
return content
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for match in action_match:
|
|
||||||
tool_name = match[0].strip()
|
|
||||||
tool_input = match[1].strip().strip('"').strip("```")
|
|
||||||
try:
|
|
||||||
arguments = json.loads(tool_input)
|
|
||||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return content
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
|
||||||
tool_text = ""
|
|
||||||
for tool in tools:
|
|
||||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
|
||||||
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
|
||||||
|
|
||||||
|
|
||||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
|
||||||
if "\n" not in content:
|
|
||||||
return content
|
|
||||||
|
|
||||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
|
||||||
try:
|
|
||||||
arguments = json.loads(tool_input)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return content
|
|
||||||
|
|
||||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Formatter(ABC):
|
class Formatter(ABC):
|
||||||
slots: SLOTS = field(default_factory=list)
|
slots: SLOTS = field(default_factory=list)
|
||||||
tool_format: Optional[Literal["default", "glm4"]] = None
|
tool_format: Optional[str] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, **kwargs) -> SLOTS: ...
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
|
r"""
|
||||||
|
Forms a list of slots according to the inputs to encode.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
|
r"""
|
||||||
|
Extract a list of tuples from the response message if using tools.
|
||||||
|
|
||||||
|
Each tuple consists of function name and function arguments.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@@ -132,6 +60,7 @@ class EmptyFormatter(Formatter):
|
|||||||
if has_placeholder:
|
if has_placeholder:
|
||||||
raise ValueError("Empty formatter should not contain any placeholder.")
|
raise ValueError("Empty formatter should not contain any placeholder.")
|
||||||
|
|
||||||
|
@override
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
return self.slots
|
return self.slots
|
||||||
|
|
||||||
@@ -147,20 +76,21 @@ class StringFormatter(Formatter):
|
|||||||
if not has_placeholder:
|
if not has_placeholder:
|
||||||
raise ValueError("A placeholder is required in the string formatter.")
|
raise ValueError("A placeholder is required in the string formatter.")
|
||||||
|
|
||||||
|
@override
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
elements = []
|
elements = []
|
||||||
for slot in self.slots:
|
for slot in self.slots:
|
||||||
if isinstance(slot, str):
|
if isinstance(slot, str):
|
||||||
for name, value in kwargs.items():
|
for name, value in kwargs.items():
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise RuntimeError("Expected a string, got {}".format(value))
|
raise RuntimeError(f"Expected a string, got {value}")
|
||||||
|
|
||||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
elif isinstance(slot, (dict, set)):
|
elif isinstance(slot, (dict, set)):
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@@ -168,16 +98,9 @@ class StringFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FunctionFormatter(Formatter):
|
class FunctionFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
has_name, has_args = False, False
|
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||||
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
|
||||||
if "{{name}}" in slot:
|
|
||||||
has_name = True
|
|
||||||
if "{{arguments}}" in slot:
|
|
||||||
has_args = True
|
|
||||||
|
|
||||||
if not has_name or not has_args:
|
|
||||||
raise ValueError("Name and arguments placeholders are required in the function formatter.")
|
|
||||||
|
|
||||||
|
@override
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
functions: List[Tuple[str, str]] = []
|
functions: List[Tuple[str, str]] = []
|
||||||
@@ -190,7 +113,7 @@ class FunctionFormatter(Formatter):
|
|||||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
functions = []
|
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for name, arguments in functions:
|
for name, arguments in functions:
|
||||||
@@ -201,7 +124,7 @@ class FunctionFormatter(Formatter):
|
|||||||
elif isinstance(slot, (dict, set)):
|
elif isinstance(slot, (dict, set)):
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@@ -209,22 +132,17 @@ class FunctionFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolFormatter(Formatter):
|
class ToolFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tool_format == "default":
|
self.tool_utils = get_tool_utils(self.tool_format)
|
||||||
self._tool_formatter = default_tool_formatter
|
|
||||||
self._tool_extractor = default_tool_extractor
|
|
||||||
elif self.tool_format == "glm4":
|
|
||||||
self._tool_formatter = glm4_tool_formatter
|
|
||||||
self._tool_extractor = glm4_tool_extractor
|
|
||||||
else:
|
|
||||||
raise ValueError("Tool format was not found.")
|
|
||||||
|
|
||||||
|
@override
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
try:
|
try:
|
||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return [""]
|
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||||
|
|
||||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
@override
|
||||||
return self._tool_extractor(content)
|
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
|
return self.tool_utils.tool_extractor(content)
|
||||||
|
|||||||
@@ -12,22 +12,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset, load_from_disk
|
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import FILEEXT2TYPE
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import has_tokenized_data
|
from ..extras.misc import has_tokenized_data
|
||||||
from .aligner import align_dataset
|
from .aligner import align_dataset
|
||||||
from .data_utils import merge_dataset
|
from .data_utils import merge_dataset, split_dataset
|
||||||
from .parser import get_dataset_list
|
from .parser import get_dataset_list
|
||||||
from .preprocess import get_preprocess_and_print_func
|
from .preprocess import get_preprocess_and_print_func
|
||||||
from .template import get_template_and_fix_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -35,21 +34,26 @@ if TYPE_CHECKING:
|
|||||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ..hparams import DataArguments, ModelArguments
|
from ..hparams import DataArguments, ModelArguments
|
||||||
|
from .data_utils import DatasetModule
|
||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_single_dataset(
|
def _load_single_dataset(
|
||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
r"""
|
||||||
|
Loads a single dataset and aligns it to the standard format.
|
||||||
|
"""
|
||||||
|
logger.info_rank0(f"Loading dataset {dataset_attr}...")
|
||||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||||
data_path = dataset_attr.dataset_name
|
data_path = dataset_attr.dataset_name
|
||||||
data_name = dataset_attr.subset
|
data_name = dataset_attr.subset
|
||||||
data_dir = dataset_attr.folder
|
data_dir = dataset_attr.folder
|
||||||
@@ -65,65 +69,71 @@ def load_single_dataset(
|
|||||||
if os.path.isdir(local_path): # is directory
|
if os.path.isdir(local_path): # is directory
|
||||||
for file_name in os.listdir(local_path):
|
for file_name in os.listdir(local_path):
|
||||||
data_files.append(os.path.join(local_path, file_name))
|
data_files.append(os.path.join(local_path, file_name))
|
||||||
if data_path is None:
|
|
||||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
|
||||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
|
||||||
raise ValueError("File types should be identical.")
|
|
||||||
elif os.path.isfile(local_path): # is file
|
elif os.path.isfile(local_path): # is file
|
||||||
data_files.append(local_path)
|
data_files.append(local_path)
|
||||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("File {} not found.".format(local_path))
|
raise ValueError(f"File {local_path} not found.")
|
||||||
|
|
||||||
|
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||||
|
|
||||||
|
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
|
||||||
|
raise ValueError("File types should be identical.")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||||
|
|
||||||
if dataset_attr.load_from == "ms_hub":
|
if dataset_attr.load_from == "ms_hub":
|
||||||
try:
|
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||||
from modelscope import MsDataset
|
from modelscope import MsDataset # type: ignore
|
||||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||||
dataset = MsDataset.load(
|
dataset = MsDataset.load(
|
||||||
dataset_name=data_path,
|
dataset_name=data_path,
|
||||||
subset_name=data_name,
|
subset_name=data_name,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=data_args.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
use_streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
if isinstance(dataset, MsDataset):
|
if isinstance(dataset, MsDataset):
|
||||||
dataset = dataset.to_hf_dataset()
|
dataset = dataset.to_hf_dataset()
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
elif dataset_attr.load_from == "om_hub":
|
||||||
|
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||||
|
from openmind import OmDataset # type: ignore
|
||||||
|
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
|
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
||||||
|
dataset = OmDataset.load_dataset(
|
||||||
|
path=data_path,
|
||||||
|
name=data_name,
|
||||||
|
data_dir=data_dir,
|
||||||
|
data_files=data_files,
|
||||||
|
split=dataset_attr.split,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
token=model_args.om_hub_token,
|
||||||
|
streaming=data_args.streaming,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
|
||||||
kwargs = {"trust_remote_code": True}
|
|
||||||
else:
|
|
||||||
kwargs = {}
|
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path=data_path,
|
path=data_path,
|
||||||
name=data_name,
|
name=data_name,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=data_args.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=data_args.streaming,
|
||||||
**kwargs,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
|
||||||
|
|
||||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||||
target_num = dataset_attr.num_samples
|
target_num = dataset_attr.num_samples
|
||||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||||
target_num -= len(indexes)
|
target_num -= len(indexes)
|
||||||
if target_num > 0:
|
if target_num > 0:
|
||||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||||
@@ -131,7 +141,7 @@ def load_single_dataset(
|
|||||||
|
|
||||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||||
dataset = dataset.select(indexes)
|
dataset = dataset.select(indexes)
|
||||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||||
|
|
||||||
if data_args.max_samples is not None: # truncate dataset
|
if data_args.max_samples is not None: # truncate dataset
|
||||||
max_samples = min(data_args.max_samples, len(dataset))
|
max_samples = min(data_args.max_samples, len(dataset))
|
||||||
@@ -140,71 +150,156 @@ def load_single_dataset(
|
|||||||
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_merged_dataset(
|
||||||
|
dataset_names: Optional[Sequence[str]],
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
|
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||||
|
r"""
|
||||||
|
Gets the merged datasets in the standard format.
|
||||||
|
"""
|
||||||
|
if dataset_names is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
datasets = []
|
||||||
|
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
|
||||||
|
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||||
|
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||||
|
|
||||||
|
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
||||||
|
|
||||||
|
return merge_dataset(datasets, data_args, seed=training_args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_preprocessed_dataset(
|
||||||
|
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"] = None,
|
||||||
|
is_eval: bool = False,
|
||||||
|
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||||
|
r"""
|
||||||
|
Preprocesses the dataset, including format checking and tokenization.
|
||||||
|
"""
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||||
|
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
|
||||||
|
)
|
||||||
|
column_names = list(next(iter(dataset)).keys())
|
||||||
|
kwargs = {}
|
||||||
|
if not data_args.streaming:
|
||||||
|
kwargs = dict(
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||||
|
desc="Running tokenizer on dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = dataset.map(
|
||||||
|
preprocess_func,
|
||||||
|
batched=True,
|
||||||
|
batch_size=data_args.preprocessing_batch_size,
|
||||||
|
remove_columns=column_names,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_args.should_log:
|
||||||
|
try:
|
||||||
|
print("eval example:" if is_eval else "training example:")
|
||||||
|
print_function(next(iter(dataset)))
|
||||||
|
except StopIteration:
|
||||||
|
if stage == "pt":
|
||||||
|
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
|
template: "Template",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"] = None,
|
processor: Optional["ProcessorMixin"] = None,
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> "DatasetModule":
|
||||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
r"""
|
||||||
if data_args.train_on_prompt and template.efficient_eos:
|
Gets the train dataset and optionally gets the evaluation dataset.
|
||||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
"""
|
||||||
|
|
||||||
# Load tokenized dataset
|
# Load tokenized dataset
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if has_tokenized_data(data_args.tokenized_path):
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
|
||||||
dataset = load_from_disk(data_args.tokenized_path)
|
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||||
|
|
||||||
|
dataset_module: Dict[str, "Dataset"] = {}
|
||||||
|
if "train" in dataset_dict:
|
||||||
|
dataset_module["train_dataset"] = dataset_dict["train"]
|
||||||
|
|
||||||
|
if "validation" in dataset_dict:
|
||||||
|
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
||||||
|
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
dataset = dataset.to_iterable_dataset()
|
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
|
||||||
return dataset
|
|
||||||
|
return dataset_module
|
||||||
|
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
||||||
|
|
||||||
|
# Load and preprocess dataset
|
||||||
with training_args.main_process_first(desc="load dataset"):
|
with training_args.main_process_first(desc="load dataset"):
|
||||||
all_datasets = []
|
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
|
||||||
for dataset_attr in get_dataset_list(data_args):
|
eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
|
||||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
|
||||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
|
||||||
|
|
||||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
|
||||||
|
|
||||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
|
||||||
|
|
||||||
with training_args.main_process_first(desc="pre-process dataset"):
|
with training_args.main_process_first(desc="pre-process dataset"):
|
||||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
dataset = _get_preprocessed_dataset(
|
||||||
data_args, training_args, stage, template, tokenizer, processor
|
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
|
||||||
|
)
|
||||||
|
eval_dataset = _get_preprocessed_dataset(
|
||||||
|
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
|
||||||
)
|
)
|
||||||
column_names = list(next(iter(dataset)).keys())
|
|
||||||
kwargs = {}
|
|
||||||
if not data_args.streaming:
|
|
||||||
kwargs = dict(
|
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
|
||||||
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
|
||||||
desc="Running tokenizer on dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
if data_args.val_size > 1e-6:
|
||||||
|
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
|
||||||
|
else:
|
||||||
|
dataset_dict = {}
|
||||||
|
if dataset is not None:
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
|
|
||||||
|
dataset_dict["train"] = dataset
|
||||||
|
|
||||||
|
if eval_dataset is not None:
|
||||||
|
if data_args.streaming:
|
||||||
|
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
|
|
||||||
|
dataset_dict["validation"] = eval_dataset
|
||||||
|
|
||||||
|
dataset_dict = DatasetDict(dataset_dict)
|
||||||
|
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
dataset.save_to_disk(data_args.tokenized_path)
|
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if training_args.should_log:
|
dataset_module = {}
|
||||||
try:
|
if "train" in dataset_dict:
|
||||||
print_function(next(iter(dataset)))
|
dataset_module["train_dataset"] = dataset_dict["train"]
|
||||||
except StopIteration:
|
|
||||||
if stage == "pt":
|
|
||||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
|
||||||
|
|
||||||
return dataset
|
if "validation" in dataset_dict:
|
||||||
|
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
||||||
|
|
||||||
|
return dataset_module
|
||||||
|
|||||||
787
src/llamafactory/data/mm_plugin.py
Normal file
787
src/llamafactory/data/mm_plugin.py
Normal file
@@ -0,0 +1,787 @@
|
|||||||
|
import math
|
||||||
|
from copy import deepcopy
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers.image_utils import get_image_size, to_numpy_array
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
|
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
|
||||||
|
|
||||||
|
|
||||||
|
if is_pillow_available():
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import Image as ImageObject
|
||||||
|
|
||||||
|
|
||||||
|
if is_pyav_available():
|
||||||
|
import av
|
||||||
|
|
||||||
|
|
||||||
|
if is_transformers_version_greater_than("4.45.0"):
|
||||||
|
from transformers.models.mllama.processing_mllama import (
|
||||||
|
convert_sparse_cross_attention_mask_to_dense,
|
||||||
|
get_cross_attention_token_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from av.stream import Stream
|
||||||
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
from transformers.image_processing_utils import BaseImageProcessor
|
||||||
|
|
||||||
|
class EncodedImage(TypedDict):
|
||||||
|
path: Optional[str]
|
||||||
|
bytes: Optional[bytes]
|
||||||
|
|
||||||
|
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
|
||||||
|
VideoInput = str
|
||||||
|
|
||||||
|
|
||||||
|
def _get_paligemma_token_type_ids(
|
||||||
|
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
|
||||||
|
) -> List[List[int]]:
|
||||||
|
r"""
|
||||||
|
Gets paligemma token type ids for computing loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch_token_type_ids: shape (batch_size, sequence_length)
|
||||||
|
"""
|
||||||
|
batch_token_type_ids = []
|
||||||
|
for imglen, seqlen in zip(imglens, seqlens):
|
||||||
|
image_seqlen = imglen * getattr(processor, "image_seqlen")
|
||||||
|
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
|
||||||
|
|
||||||
|
return batch_token_type_ids
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlugin:
|
||||||
|
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
||||||
|
self.image_token = image_token
|
||||||
|
self.video_token = video_token
|
||||||
|
|
||||||
|
def _validate_input(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Validates if this model accepts the input modalities.
|
||||||
|
"""
|
||||||
|
if len(images) != 0 and self.image_token is None:
|
||||||
|
raise ValueError("This model does not support image input.")
|
||||||
|
|
||||||
|
if len(videos) != 0 and self.video_token is None:
|
||||||
|
raise ValueError("This model does not support video input.")
|
||||||
|
|
||||||
|
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||||
|
r"""
|
||||||
|
Pre-processes a single image.
|
||||||
|
"""
|
||||||
|
image_resolution: int = kwargs.get("image_resolution")
|
||||||
|
if (image.width * image.height) > image_resolution:
|
||||||
|
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
|
||||||
|
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||||
|
image = image.resize((width, height), resample=Image.NEAREST)
|
||||||
|
|
||||||
|
if image.mode != "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
|
||||||
|
r"""
|
||||||
|
Computes video sample frames according to fps.
|
||||||
|
"""
|
||||||
|
video_fps: float = kwargs.get("video_fps")
|
||||||
|
video_maxlen: int = kwargs.get("video_maxlen")
|
||||||
|
total_frames = video_stream.frames
|
||||||
|
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||||
|
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||||
|
return math.floor(sample_frames)
|
||||||
|
|
||||||
|
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
|
||||||
|
r"""
|
||||||
|
Regularizes images to avoid error. Including reading and pre-processing.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for image in images:
|
||||||
|
if isinstance(image, str):
|
||||||
|
image = Image.open(image)
|
||||||
|
elif isinstance(image, bytes):
|
||||||
|
image = Image.open(BytesIO(image))
|
||||||
|
elif isinstance(image, dict):
|
||||||
|
if image["bytes"] is not None:
|
||||||
|
image = Image.open(BytesIO(image["bytes"]))
|
||||||
|
else:
|
||||||
|
image = Image.open(image["path"])
|
||||||
|
|
||||||
|
if not isinstance(image, ImageObject):
|
||||||
|
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
|
||||||
|
|
||||||
|
results.append(self._preprocess_image(image, **kwargs))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
|
||||||
|
r"""
|
||||||
|
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for video in videos:
|
||||||
|
container = av.open(video, "r")
|
||||||
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
|
total_frames = video_stream.frames
|
||||||
|
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
|
||||||
|
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||||
|
frames: List["ImageObject"] = []
|
||||||
|
container.seek(0)
|
||||||
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
|
if frame_idx in sample_indices:
|
||||||
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
|
frames = self._regularize_images(frames, **kwargs)
|
||||||
|
results.append(frames)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: "ProcessorMixin",
|
||||||
|
) -> Dict[str, "torch.Tensor"]:
|
||||||
|
r"""
|
||||||
|
Processes visual inputs.
|
||||||
|
|
||||||
|
Returns: (llava and paligemma)
|
||||||
|
pixel_values: tensor with shape (B, C, H, W)
|
||||||
|
|
||||||
|
Returns: (qwen2-vl)
|
||||||
|
pixel_values: tensor with shape (num_patches, patch_dim)
|
||||||
|
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
|
||||||
|
|
||||||
|
It holds num_patches == torch.prod(image_grid_thw)
|
||||||
|
"""
|
||||||
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||||
|
input_dict = {"images": None} # default key
|
||||||
|
if len(images) != 0:
|
||||||
|
images = self._regularize_images(
|
||||||
|
images,
|
||||||
|
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
||||||
|
)
|
||||||
|
input_dict["images"] = images
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
videos = self._regularize_videos(
|
||||||
|
videos,
|
||||||
|
image_resolution=getattr(processor, "video_resolution", 128 * 128),
|
||||||
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
|
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||||
|
)
|
||||||
|
input_dict["videos"] = videos
|
||||||
|
|
||||||
|
mm_inputs = {}
|
||||||
|
if image_processor != video_processor:
|
||||||
|
if input_dict.get("images") is not None:
|
||||||
|
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
|
||||||
|
if input_dict.get("videos") is not None:
|
||||||
|
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
|
||||||
|
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
|
||||||
|
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
def process_messages(
    self,
    messages: Sequence[Dict[str, str]],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
    r"""
    Pre-processes input messages before tokenization for VLMs.

    The base implementation only runs ``_validate_input`` on the visual inputs
    and returns the messages unchanged; subclasses override this to expand
    image/video placeholder tokens into model-specific token sequences.
    """
    self._validate_input(images, videos)
    return messages
|
||||||
|
|
||||||
|
def process_token_ids(
    self,
    input_ids: List[int],
    labels: Optional[List[int]],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
    r"""
    Pre-processes token ids after tokenization for VLMs.

    The base implementation only runs ``_validate_input`` and returns the ids
    and labels unchanged; subclasses (e.g. PaliGemma) override this to inject
    multimodal tokens directly into the token-id sequence.
    """
    self._validate_input(images, videos)
    return input_ids, labels
|
||||||
|
|
||||||
|
def get_mm_inputs(
    self,
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    imglens: Sequence[int],
    vidlens: Sequence[int],
    batch_ids: Sequence[List[int]],
    processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
    r"""
    Builds batched multimodal inputs for VLMs.

    The base implementation validates the inputs and returns an empty dict;
    subclasses override this to produce processor-specific tensors.

    Arguments:
        images: a list of image inputs, shape (num_images,)
        videos: a list of video inputs, shape (num_videos,)
        imglens: number of images in each sample, shape (batch_size,)
        vidlens: number of videos in each sample, shape (batch_size,)
        batch_ids: input ids of samples, shape (batch_size, seq_len)
        processor: a processor for pre-processing images and videos
    """
    self._validate_input(images, videos)
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands each image placeholder into ``processor.image_seqlen`` copies of the image token.
        """
        self._validate_input(images, videos)
        image_seqlen = getattr(processor, "image_seqlen")
        replaced_images = 0
        messages = deepcopy(messages)
        for message in messages:
            text = message["content"]
            # expand one placeholder at a time so each occurrence is counted
            while IMAGE_PLACEHOLDER in text:
                replaced_images += 1
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = text.replace("{{image}}", self.image_token)

        if len(images) != replaced_images:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Delegates to the shared ``_get_mm_inputs`` helper after validating the inputs.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Replaces each image placeholder with the number of image tokens the
        processor computes for that image's original resolution.
        """
        self._validate_input(images, videos)
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if "image_sizes" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"])

        if "pixel_values" in mm_inputs:
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        # NOTE(review): if a placeholder appears while mm_inputs lacks
        # "image_sizes"/"pixel_values", `image_sizes`/`height` are unbound and a
        # NameError is raised — presumably unreachable in practice; confirm.
        replaced_images = 0
        for message in messages:
            text = message["content"]
            while IMAGE_PLACEHOLDER in text:
                orig_height, orig_width = next(image_sizes)
                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                if getattr(processor, "vision_feature_select_strategy") == "default":
                    image_seqlen -= 1  # the "default" strategy uses one fewer feature

                replaced_images += 1
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = text.replace("{{image}}", self.image_token)

        if len(images) != replaced_images:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Delegates to the shared ``_get_mm_inputs`` helper after validating the inputs.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextVideoPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands image placeholders to per-resolution token counts and video
        placeholders to per-frame token counts (with 2x2 average pooling).
        """
        self._validate_input(images, videos)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if "pixel_values" in mm_inputs:
            # one (orig_height, orig_width) entry per image, consumed in placeholder order
            image_sizes = iter(mm_inputs["image_sizes"])
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    image_size = next(image_sizes)
                    orig_height, orig_width = image_size
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if getattr(processor, "vision_feature_select_strategy") == "default":
                        # the "default" strategy uses one fewer feature
                        image_seqlen -= 1

                    num_image_tokens += 1
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                message["content"] = content.replace("{{image}}", self.image_token)

        if "pixel_values_videos" in mm_inputs:
            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
            height, width = get_image_size(pixel_values_video[0])
            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
            for message in messages:
                content = message["content"]
                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

                message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Delegates to the shared ``_get_mm_inputs`` helper after validating the inputs.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class PaliGemmaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Strips image placeholders from the text; the image tokens themselves
        are prepended to the token ids in ``process_token_ids``.
        """
        self._validate_input(images, videos)
        replaced_images = 0
        messages = deepcopy(messages)
        for message in messages:
            text = message["content"]
            while IMAGE_PLACEHOLDER in text:
                replaced_images += 1
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

            message["content"] = text.replace("{{image}}", "")

        if len(images) != replaced_images:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Prepends ``image_seqlen`` image tokens per image to the input ids,
        masking the prepended positions out of the labels with IGNORE_INDEX.
        """
        self._validate_input(images, videos)
        image_seqlen = len(images) * getattr(processor, "image_seqlen")
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seqlen + labels

        return input_ids, labels

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds the visual inputs plus PaliGemma's ``token_type_ids`` mask.
        """
        self._validate_input(images, videos)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        sequence_lengths = [len(sample_ids) for sample_ids in batch_ids]
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, sequence_lengths, processor)
        return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Replaces each image placeholder with a patch grid of image tokens:
        rows of image tokens separated by break tokens, terminated by an
        end-of-image token.
        """
        self._validate_input(images, videos)
        patch_size = getattr(processor, "patch_size")
        image_token = getattr(processor, "image_token")
        image_break_token = getattr(processor, "image_break_token")
        image_end_token = getattr(processor, "image_end_token")

        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        image_input_sizes = mm_inputs.get("image_sizes", None)
        image_idx = 0
        for message in messages:
            text = message["content"]
            while IMAGE_PLACEHOLDER in text:
                if image_input_sizes is None:
                    raise ValueError("Cannot get image input sizes.")

                height, width = image_input_sizes[0][image_idx]
                rows = height // patch_size
                cols = width // patch_size
                # one patch row of image tokens followed by a break token, repeated per row
                grid_tokens = ([image_token] * cols + [image_break_token]) * rows
                grid_tokens[-1] = image_end_token  # the final break becomes the end marker
                text = text.replace(IMAGE_PLACEHOLDER, "".join(grid_tokens), 1)
                image_idx += 1

            message["content"] = text

        if len(images) != image_idx:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds visual inputs, unwrapping the single-sample ``pixel_values``
        batch dimension and dropping ``image_sizes`` (not a model input).
        """
        self._validate_input(images, videos)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if mm_inputs.get("pixel_values"):
            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]

        mm_inputs.pop("image_sizes", None)
        return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2vlPlugin(BasePlugin):
    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""
        Applies the base preprocessing, then enforces qwen2-vl's size limits:
        both sides at least 28 pixels, aspect ratio clamped below 200.
        """
        image = super()._preprocess_image(image, **kwargs)
        if min(image.width, image.height) < 28:
            image = image.resize((max(image.width, 28), max(image.height, 28)), resample=Image.NEAREST)

        if image.width / image.height > 200:
            image = image.resize((image.height * 180, image.height), resample=Image.NEAREST)

        if image.height / image.width > 200:
            image = image.resize((image.width, image.width * 180), resample=Image.NEAREST)

        return image

    @override
    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
        r"""
        Rounds the sampled frame count down to an even number, as required by qwen2-vl.
        """
        return super()._get_video_sample_frames(video_stream, **kwargs) // 2 * 2

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands every image/video placeholder into a ``<|vision_start|>...<|vision_end|>``
        span whose token count is derived from the patch grid and merge size.
        """
        self._validate_input(images, videos)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        image_grid_thw = mm_inputs.get("image_grid_thw", [])
        video_grid_thw = mm_inputs.get("video_grid_thw", [])

        image_idx, video_idx = 0, 0
        messages = deepcopy(messages)
        for message in messages:
            text = message["content"]
            while IMAGE_PLACEHOLDER in text:
                if image_idx >= len(image_grid_thw):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                # number of visual tokens = prod(t, h, w) / merge_size**2
                seqlen = image_grid_thw[image_idx].prod() // merge_length
                text = text.replace(
                    IMAGE_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format(self.image_token * seqlen), 1
                )
                image_idx += 1

            while VIDEO_PLACEHOLDER in text:
                if video_idx >= len(video_grid_thw):
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                seqlen = video_grid_thw[video_idx].prod() // merge_length
                text = text.replace(
                    VIDEO_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format(self.video_token * seqlen), 1
                )
                video_idx += 1

            message["content"] = text

        if len(images) != image_idx:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != video_idx:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Delegates to the shared ``_get_mm_inputs`` helper after validating the inputs.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Expands image placeholders to one patch-grid's worth of image tokens
        and video placeholders to ``image_seqlen * num_frames`` video tokens.
        """
        self._validate_input(images, videos)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        num_frames = 0
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
        if has_images or has_videos:
            if has_images:
                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                num_frames = 1  # a still image counts as a single frame

            if has_videos:
                # when both are present, the video dimensions take precedence
                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(pixel_values_video[0])
                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

            # +1 accounts for an extra (presumably CLS) position — confirm against HF VideoLlava
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
            video_seqlen = image_seqlen * num_frames
            # NOTE(review): video_seqlen is computed before the "default"-strategy
            # adjustment, so only image_seqlen is reduced — presumably intentional; verify.
            if getattr(processor, "vision_feature_select_strategy") == "default":
                image_seqlen -= 1

            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    num_image_tokens += 1
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

                content = content.replace("{{image}}", self.image_token)
                message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Delegates to the shared ``_get_mm_inputs`` helper after validating the inputs.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Swaps every image placeholder for the model's single image token.
        """
        self._validate_input(images, videos)
        placeholder_count = 0
        messages = deepcopy(messages)
        for message in messages:
            text = message["content"]
            placeholder_count += text.count(IMAGE_PLACEHOLDER)
            message["content"] = text.replace(IMAGE_PLACEHOLDER, self.image_token)

        if len(images) != placeholder_count:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Returns:
            pixel_values: tensor with shape
                (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
        # wrap each image in its own inner list: one image per sample
        return image_processor([[image] for image in images], return_tensors="pt")

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds the visual inputs plus mllama's dense cross-attention mask.
        """
        self._validate_input(images, videos)
        if len(images) != len(batch_ids):
            raise ValueError("Mllama only supports one image per sample.")

        mm_inputs = self._get_mm_inputs(images, videos, processor)
        num_tiles = mm_inputs.pop("num_tiles")
        image_token_id = getattr(processor, "image_token_id")
        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
        sparse_masks = [get_cross_attention_token_mask(sample_ids, image_token_id) for sample_ids in batch_ids]
        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
            sparse_masks,
            num_tiles=num_tiles,
            max_num_tiles=max_image_tiles,
            length=max(len(sample_ids) for sample_ids in batch_ids),
        )
        return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping plugin names to their classes; consumed by `get_mm_plugin`.
PLUGINS = {
    "base": BasePlugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_vl": Qwen2vlPlugin,
    "video_llava": VideoLlavaPlugin,
    "mllama": MllamaPlugin,
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
) -> "BasePlugin":
    r"""
    Looks up the multimodal plugin registered under ``name`` and instantiates
    it with the given image/video tokens.

    Raises:
        ValueError: if ``name`` is not a registered plugin.
    """
    plugin_class = PLUGINS.get(name)
    if plugin_class is None:
        raise ValueError(f"Multimodal plugin `{name}` not found.")

    return plugin_class(image_token, video_token)
|
||||||
@@ -15,14 +15,12 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from ..extras.constants import DATA_CONFIG
|
from ..extras.constants import DATA_CONFIG
|
||||||
from ..extras.misc import use_modelscope
|
from ..extras.misc import use_modelscope, use_openmind
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..hparams import DataArguments
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -31,31 +29,33 @@ class DatasetAttr:
|
|||||||
Dataset attributes.
|
Dataset attributes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
""" basic configs """
|
# basic configs
|
||||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
|
||||||
dataset_name: str
|
dataset_name: str
|
||||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||||
ranking: bool = False
|
ranking: bool = False
|
||||||
""" extra configs """
|
# extra configs
|
||||||
subset: Optional[str] = None
|
subset: Optional[str] = None
|
||||||
|
split: str = "train"
|
||||||
folder: Optional[str] = None
|
folder: Optional[str] = None
|
||||||
num_samples: Optional[int] = None
|
num_samples: Optional[int] = None
|
||||||
""" common columns """
|
# common columns
|
||||||
system: Optional[str] = None
|
system: Optional[str] = None
|
||||||
tools: Optional[str] = None
|
tools: Optional[str] = None
|
||||||
images: Optional[str] = None
|
images: Optional[str] = None
|
||||||
""" rlhf columns """
|
videos: Optional[str] = None
|
||||||
|
# rlhf columns
|
||||||
chosen: Optional[str] = None
|
chosen: Optional[str] = None
|
||||||
rejected: Optional[str] = None
|
rejected: Optional[str] = None
|
||||||
kto_tag: Optional[str] = None
|
kto_tag: Optional[str] = None
|
||||||
""" alpaca columns """
|
# alpaca columns
|
||||||
prompt: Optional[str] = "instruction"
|
prompt: Optional[str] = "instruction"
|
||||||
query: Optional[str] = "input"
|
query: Optional[str] = "input"
|
||||||
response: Optional[str] = "output"
|
response: Optional[str] = "output"
|
||||||
history: Optional[str] = None
|
history: Optional[str] = None
|
||||||
""" sharegpt columns """
|
# sharegpt columns
|
||||||
messages: Optional[str] = "conversations"
|
messages: Optional[str] = "conversations"
|
||||||
""" sharegpt tags """
|
# sharegpt tags
|
||||||
role_tag: Optional[str] = "from"
|
role_tag: Optional[str] = "from"
|
||||||
content_tag: Optional[str] = "value"
|
content_tag: Optional[str] = "value"
|
||||||
user_tag: Optional[str] = "human"
|
user_tag: Optional[str] = "human"
|
||||||
@@ -71,45 +71,55 @@ class DatasetAttr:
|
|||||||
setattr(self, key, obj.get(key, default))
|
setattr(self, key, obj.get(key, default))
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
||||||
if data_args.dataset is not None:
|
r"""
|
||||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
|
Gets the attributes of the datasets.
|
||||||
else:
|
"""
|
||||||
|
if dataset_names is None:
|
||||||
dataset_names = []
|
dataset_names = []
|
||||||
|
|
||||||
if data_args.dataset_dir == "ONLINE":
|
if dataset_dir == "ONLINE":
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
else:
|
else:
|
||||||
|
if dataset_dir.startswith("REMOTE:"):
|
||||||
|
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||||
|
else:
|
||||||
|
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
with open(config_path) as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if len(dataset_names) != 0:
|
if len(dataset_names) != 0:
|
||||||
raise ValueError(
|
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
|
||||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
|
||||||
)
|
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
if data_args.interleave_probs is not None:
|
dataset_list: List["DatasetAttr"] = []
|
||||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
|
||||||
|
|
||||||
dataset_list: List[DatasetAttr] = []
|
|
||||||
for name in dataset_names:
|
for name in dataset_names:
|
||||||
if dataset_info is None:
|
if dataset_info is None: # dataset_dir is ONLINE
|
||||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
if use_modelscope():
|
||||||
|
load_from = "ms_hub"
|
||||||
|
elif use_openmind():
|
||||||
|
load_from = "om_hub"
|
||||||
|
else:
|
||||||
|
load_from = "hf_hub"
|
||||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||||
dataset_list.append(dataset_attr)
|
dataset_list.append(dataset_attr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if name not in dataset_info:
|
if name not in dataset_info:
|
||||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
|
||||||
|
|
||||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||||
|
has_om_url = "om_hub_url" in dataset_info[name]
|
||||||
|
|
||||||
if has_hf_url or has_ms_url:
|
if has_hf_url or has_ms_url or has_om_url:
|
||||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
if has_ms_url and (use_modelscope() or not has_hf_url):
|
||||||
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||||
|
elif has_om_url and (use_openmind() or not has_hf_url):
|
||||||
|
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
|
||||||
else:
|
else:
|
||||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||||
elif "script_url" in dataset_info[name]:
|
elif "script_url" in dataset_info[name]:
|
||||||
@@ -120,11 +130,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||||
dataset_attr.set_attr("subset", dataset_info[name])
|
dataset_attr.set_attr("subset", dataset_info[name])
|
||||||
|
dataset_attr.set_attr("split", dataset_info[name], default="train")
|
||||||
dataset_attr.set_attr("folder", dataset_info[name])
|
dataset_attr.set_attr("folder", dataset_info[name])
|
||||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||||
|
|
||||||
if "columns" in dataset_info[name]:
|
if "columns" in dataset_info[name]:
|
||||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
column_names.extend(["prompt", "query", "response", "history"])
|
column_names.extend(["prompt", "query", "response", "history"])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .template import Template
|
from .template import Template
|
||||||
@@ -35,11 +35,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
def get_preprocess_and_print_func(
|
def get_preprocess_and_print_func(
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
|
||||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
|
do_generate: bool = False,
|
||||||
) -> Tuple[Callable, Callable]:
|
) -> Tuple[Callable, Callable]:
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
@@ -48,12 +48,26 @@ def get_preprocess_and_print_func(
|
|||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not do_generate:
|
||||||
if data_args.packing:
|
if data_args.packing:
|
||||||
|
if data_args.neat_packing: # hack datasets to have int32 attention mask
|
||||||
|
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
||||||
|
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
return TypedSequence.__init__(
|
||||||
|
self,
|
||||||
|
data,
|
||||||
|
type=kwargs.pop("type", None),
|
||||||
|
try_type=kwargs.pop("try_type", None),
|
||||||
|
optimized_int_type=kwargs.pop("optimized_int_type", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
OptimizedTypedSequence.__init__ = __init__
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_packed_supervised_dataset,
|
preprocess_packed_supervised_dataset,
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user