159 Commits

Author SHA1 Message Date
hiyouga
dfff411e1a release v0.9.0 (real)
Former-commit-id: 8ff781c8ae5654680f738f69a6db9d7b95d76baf
2024-09-09 01:00:25 +08:00
hiyouga
e20baa4218 fix constants
Former-commit-id: fce6671d2764d7a2b77c44401fc5582c7cbb77aa
2024-09-08 23:52:30 +08:00
hiyouga
d1ab9b501a release v0.9.0
Former-commit-id: 594c450f648ad326ef39c0f4d70d67cda5f36159
2024-09-08 23:43:35 +08:00
hiyouga
3cbc9109ea tiny fix
Former-commit-id: 76177039c8f9ef5a63724a339dae6195d89fa215
2024-09-08 23:18:08 +08:00
hiyouga
3259397f89 update scripts
Former-commit-id: 51d087cbc14bf3c7dfa06b8b66052cd80a6081be
2024-09-08 14:17:41 +08:00
hiyouga
eb5af3d90b support vllm 0.6.0
Former-commit-id: e39470ec51a9c74ad901871eb816df10e851f351
2024-09-08 02:26:20 +08:00
hiyouga
b6810b209a fix test case
Former-commit-id: b075b2971c6acb2c6039b36420a296f1f4e1b91b
2024-09-08 01:50:51 +08:00
hiyouga
158e0e1f63 add test case
Former-commit-id: c452d65e1551074dddd1d87517c0d44dc014c6aa
2024-09-08 01:40:49 +08:00
hiyouga
294a103ead support activation offloading via unsloth gc
Former-commit-id: d3d0dd0feba3ca6f0ae970d5856bec989d26ef67
2024-09-08 01:22:19 +08:00
hiyouga
7f71276ad8 add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
2024-09-08 00:56:56 +08:00
hoshi-hiyouga
93d4570a59 Merge pull request #5388 from yzoaim/cal_mfu_update
update cal_mfu.py

Former-commit-id: fe5eac2cb6a4646b653232d7c68d535105b60f3a
2024-09-08 00:49:28 +08:00
hoshi-hiyouga
527ba2eb2e fix
Former-commit-id: 53a74cbc3afec58b36c2265e080061bcdf702f98
2024-09-08 00:41:45 +08:00
hoshi-hiyouga
3021b31cf3 Update cal_mfu.py
Former-commit-id: 0c391b2e59943b0ca9dd4e8561398e7c856a4b29
2024-09-08 00:39:48 +08:00
-.-
9f2427907e update cal_mfu.py
Former-commit-id: 1cdbb4c774d463969c6be14fb08d92c7a0bdb565
2024-09-07 23:21:35 +08:00
hoshi-hiyouga
570ce100c1 fix #5384
Former-commit-id: 2e86c54f381f7403c30ba78d2acf5003aab6e049
2024-09-07 01:21:14 +08:00
hiyouga
27547355e6 tiny fix
Former-commit-id: c0e9c0484dae6db93cef5048bad827ff22b1986a
2024-09-05 23:41:16 +08:00
hiyouga
c5ef52a67a fix ci
Former-commit-id: b5ffca5a190f3aed8ba8c49bd8cf3239fb787bf5
2024-09-05 22:39:47 +08:00
hiyouga
b48b47d519 fix ci
Former-commit-id: cf0758b03e9b8b4931ba790a9726b8256ee4286c
2024-09-05 22:27:48 +08:00
hiyouga
9bdba2f6a8 add e2e tests
Former-commit-id: 0156a37450604641c4f5f9756ad84324698fc88c
2024-09-05 21:52:28 +08:00
hoshi-hiyouga
d6ce902d80 Merge pull request #5372 from LDLINGLINGLING/main
Added support for MiniCPM 3.0

Former-commit-id: 2e3c221d9c87bd59f48648be8878b7b50347280f
2024-09-05 21:35:42 +08:00
liudan
ce6dcf3600 根据代码规范修改了代码
Former-commit-id: fe5351980b42e0e38175b0da2401a61b3807fa7c
2024-09-05 20:17:55 +08:00
hoshi-hiyouga
e7f92d16d8 fix #5366
Former-commit-id: b0a4964846dd5be7aa2c54d43f28ba62985587f1
2024-09-05 18:08:09 +08:00
hiyouga
abd26f5f67 update data readme
Former-commit-id: 0af5f054b7b8da8b39eb44b1dfa76050f0c45667
2024-09-05 04:44:49 +08:00
hiyouga
4d35ace75e update data readme
Former-commit-id: 81adb153b7d0b30e6cd50c9bf4ca1ccf17458611
2024-09-05 04:25:27 +08:00
hiyouga
72222d1598 support Yi-Coder models
Former-commit-id: ea3f1659e70541c4fa8b7079a0a8c94fce9a41c8
2024-09-05 03:12:24 +08:00
hiyouga
26d914b8fc fix ci
Former-commit-id: 280c0f3f2cea4dfced797cc0e15f72b8b3a93542
2024-09-05 03:02:59 +08:00
hiyouga
7b01c0676c fix ci
Former-commit-id: 7899b44b19c3d0a70706d987bb7d2e0e3536014b
2024-09-05 02:49:22 +08:00
hiyouga
571a9b8669 update ci
Former-commit-id: e24bf7345442701ca874d439f0ca3da49fa59a84
2024-09-05 02:26:10 +08:00
hoshi-hiyouga
ed35eb1e9e Merge pull request #5365 from hiyouga/video_finetuning
Support Qwen2-VL Fine-Tuning on Video Datasets

Former-commit-id: 178cc3fbc48bf2c68723b487681db04e660b12fa
2024-09-05 02:24:58 +08:00
hiyouga
d291e0d60d tiny fix
Former-commit-id: 9da6e084e1e5daf7403e7fabeaaec686167fb11f
2024-09-05 02:16:49 +08:00
hiyouga
1874d579c5 video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
2024-09-05 02:04:17 +08:00
liudan
c692339020 Added support for MiniCPM 3.0
Former-commit-id: 4ad3a761af2452ef3f6c61190b7e47c9ea5227b9
2024-09-04 23:10:05 +08:00
hiyouga
2c1eef34cb fix test
Former-commit-id: 553a83aff9f9da35c9a0eca81f7d2b0bf2adf6ff
2024-09-04 22:38:26 +08:00
hiyouga
af178cbcd1 update get template
Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
2024-09-04 22:36:20 +08:00
hoshi-hiyouga
5d85be31ca Merge pull request #5323 from naem1023/feat/add-dataset-map-batch-size-argument
Add batch size of map function in the preprocessed dataset

Former-commit-id: c3428c5807500d087cdee4386798e10e39c9cf30
2024-09-04 22:09:36 +08:00
hoshi-hiyouga
372b71c847 fix #5228
Former-commit-id: 0d332ca8d0987c0331361934ab110fafa6402a7e
2024-09-04 19:10:30 +08:00
hiyouga
41a9c415e1 fix #5252
Former-commit-id: 73f30b4dfffb260e24f9e2332617b8ca2c249ed5
2024-09-04 03:17:54 +08:00
hiyouga
915e32a5f8 add vl_feedback dataset
Former-commit-id: 6ff34ad2db383b5fbd51008bcc5eec880658811e
2024-09-04 03:13:03 +08:00
hiyouga
f4dd429cbf fix #5344
Former-commit-id: 9d445c0b5be5ccc0e6d1979e76a869ddf92d9534
2024-09-04 03:06:06 +08:00
hoshi-hiyouga
7435cde2ef Merge pull request #5346 from hiyouga/lazy_image
[exp] Lazyload for multimodal inputs

Former-commit-id: 4bbd721361a8c5888b28f5fcdcbb2a4ad2305445
2024-09-04 03:00:53 +08:00
hiyouga
7056087e92 lazy image load
Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
2024-09-04 02:27:08 +08:00
hiyouga
fed7ae5661 fix #5334
Former-commit-id: a5ea0f83f00c81d128a1f50ce244866ce38ee15f
2024-09-03 19:09:42 +08:00
hiyouga
5019c6148b fix #5338
Former-commit-id: a66ddfea218feefde50fa097d20b4bcbe89ab791
2024-09-03 17:45:17 +08:00
hiyouga
2e1396cd6b lint
Former-commit-id: d821d933e6cb982d648a69f85f6ad01d0560ed70
2024-09-03 00:46:25 +08:00
hiyouga
b5e9df5df8 fix #5324
Former-commit-id: f7aa06c9c0b18c28419ea5792410915d3f322cbf
2024-09-02 23:56:21 +08:00
naem1023
3622856994 feat: add batch size of map function in the preprocessed dataset
Former-commit-id: 94b6cf06c2f84d0619b1a2dccaf8abb51de9951c
2024-09-02 13:52:47 +09:00
hoshi-hiyouga
7367c6ec21 fix trainer predict
Former-commit-id: 2790790cd26c6743105555a60523b89f367ebce3
2024-09-02 10:15:29 +08:00
hoshi-hiyouga
6579ec8c4c remove .cpu()
Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac
2024-09-02 10:10:53 +08:00
hiyouga
a7fbae47d5 fix mm inference
Former-commit-id: fa782c15a07ed40f8a6381acdf2da395377efd04
2024-09-02 01:47:40 +08:00
hiyouga
f203a9d78e tiny fix
Former-commit-id: 8b4f408da110d74285bae20bbd969013a979964b
2024-09-02 01:33:22 +08:00
hiyouga
bae73e676c add image num check
Former-commit-id: 15201113bf16b748c0a758c7a5b363da8272e0e6
2024-09-02 01:31:36 +08:00
hiyouga
806e1061d4 add pokemon dataset
Former-commit-id: 06680158a0f0a1e3c542e77af92ac877fbe357c5
2024-09-02 01:02:25 +08:00
hiyouga
f920091667 update readme
Former-commit-id: 25a05d9f96718e06ce83f5bb1f41d2c001790295
2024-09-01 23:32:39 +08:00
hiyouga
801979f779 update wechat
Former-commit-id: 7f88dfe080db10ff12d1fb80b43099a356c899ea
2024-09-01 23:30:57 +08:00
hoshi-hiyouga
df2d32e7aa Merge pull request #5317 from ByronHsu/patch-1
Add liger kernel link

Former-commit-id: a319b3cf9119fd49cbcfb17b963e111a2f86bb51
2024-09-01 23:30:12 +08:00
hiyouga
60cf12727b add rlhf-v dataset
Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
2024-09-01 22:57:41 +08:00
hiyouga
7621526d22 tiny fix
Former-commit-id: 8ccaae3871d8d1fe3ea4633d427aecb2ab3addec
2024-09-01 21:15:44 +08:00
hiyouga
559b84dceb fix bug
Former-commit-id: 6e19e56000dd18d5faf84ceabce8d7708ff21e4d
2024-09-01 21:07:49 +08:00
hiyouga
7e4c5d4bb3 fix mixed mm inputs and rlhf-v
Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
2024-09-01 20:52:47 +08:00
Byron Hsu
2a4ed6610e Add liger kernel link
Former-commit-id: 4f313044cf8efd9c6ebcbd4741f6f38d56804b7f
2024-08-30 17:16:16 -07:00
hiyouga
1d8e9c7897 fix ci (temp)
Former-commit-id: 9ebaafd2e5c16ecef0243e4df77344ed7c823e57
2024-08-31 02:03:56 +08:00
hiyouga
43654028eb add test mm plugin
Former-commit-id: ddea5cca5a3174de1dcc7fdee8ec69e77700b6bf
2024-08-31 01:53:38 +08:00
hiyouga
2f6fc27c8b remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
2024-08-31 00:24:51 +08:00
hiyouga
d789b667d7 optimize predict vram
Former-commit-id: a577e44eee351b3ed8011a33ae01cd713354ff97
2024-08-30 23:08:45 +08:00
hiyouga
66a1abac6a add examples
Former-commit-id: 169c68921b1b8ac279834b060d9e7d38a56fe1aa
2024-08-30 21:43:19 +08:00
hiyouga
665db18661 tiny fix
Former-commit-id: 830511a6d0216da99520aee8b3a753d347a71fa9
2024-08-30 03:21:50 +08:00
hiyouga
30d97ca879 fix #5307
Former-commit-id: 63c19ddfe483a16c1c9afc2f1441e8070bb0f7e4
2024-08-30 02:45:40 +08:00
hiyouga
c62a6ca59d refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
2024-08-30 02:14:31 +08:00
hoshi-hiyouga
77c2c7076b Merge pull request #5290 from simonJJJ/qwen2_vl
support qwen2-vl

Former-commit-id: 7156f832af8505b26371559d340c0e69eb962bbc
2024-08-30 02:10:36 +08:00
hoshi-hiyouga
7466fd4387 fix bug
Former-commit-id: 365e6df71509569f59c40743c115f1a4b945ef0f
2024-08-30 02:05:26 +08:00
hiyouga
c1369a1ec9 update liger kernel
Former-commit-id: d6bf6ca2161c99dd5d644e31d2b1df451017b68c
2024-08-29 20:46:08 +08:00
hiyouga
d677fe053d fix #5292
Former-commit-id: dd81ce8ce5fdf450027c5f9634abb6ac2cd52128
2024-08-29 20:37:47 +08:00
hiyouga
7c6785d3df fix #5295
Former-commit-id: c76873b0eb8225f6e6bfc7223c6012387dceb8ed
2024-08-29 20:30:18 +08:00
hiyouga
77341ee3c4 fix #5305
Former-commit-id: a710ebaf97c258c802f24e508d83f1f3f10edc6d
2024-08-29 20:16:01 +08:00
simonJJJ
5b4b60cfb5 update
Former-commit-id: a968a416d5e513320c97109229ca1e6ddc003cb1
2024-08-28 20:22:46 +08:00
simonJJJ
0f3d54d8a0 initial-commit
Former-commit-id: b6a39847a10b417b09db4b5512dd835e9e4ce928
2024-08-28 16:51:35 +08:00
hiyouga
7272792f65 update wechat
Former-commit-id: ef91752cc6f53088eaf7fc2f64f7148821d82ec2
2024-08-27 12:55:23 +08:00
hiyouga
4cc8e16595 add extra requires
Former-commit-id: c47511773ae9886aae4e5ea1841866d2125abc34
2024-08-27 12:52:12 +08:00
hiyouga
ca5a759f94 tiny fix
Former-commit-id: d2cede7023bbe28525ef8b4ad27247445d8c22e5
2024-08-27 12:49:32 +08:00
hoshi-hiyouga
be51e56a2e Merge pull request #5237 from marko1616/patch-1
Fix mllm api

Former-commit-id: 017703c7ab7f3dc566792619537c3202ca4f4bb7
2024-08-27 12:24:43 +08:00
marko1616
3a9171e275 ruff pass.
Former-commit-id: c2f817772f8e7d947dca04f546befc70001abe64
2024-08-27 11:30:16 +08:00
marko1616
bd0f3b4050 Update chat.py
Former-commit-id: 4e5893a5c4a47ff3cb989bbef0841effc713fc08
2024-08-27 11:27:56 +08:00
hiyouga
206a8364d4 support liger kernel
Former-commit-id: 0f4e54abf6c5feb2329855a4047597ad5147720a
2024-08-27 11:20:14 +08:00
marko1616
097d031066 Force re check.
Former-commit-id: 5f04452f7d65e535d0af08944f7b9e29e85f51d7
2024-08-23 14:43:18 +08:00
marko1616
2674b42b59 Update chat.py
Former-commit-id: 206a16c17d253956afb96daea6f24478e17334fc
2024-08-22 12:24:34 +08:00
marko1616
edf2e51bbc Update chat.py
Former-commit-id: edf6dc1995daa6c3635c3fda1052b340693a04f5
2024-08-22 12:14:34 +08:00
MengqingCao
47877acc2a update npu base image
Former-commit-id: 20819f7707cfff6b951484e91fc7ecda2bf68528
2024-08-21 09:12:38 +00:00
hiyouga
d111a324bc tiny fix
Former-commit-id: 23961bdf6fdbcde64e7b943f699fdeb4ac024043
2024-08-20 00:10:52 +08:00
hoshi-hiyouga
388f0a6e05 Merge pull request #5156 from YeQiuO/main
fix Llama-template's system prompt bug

Former-commit-id: 0b57175d3bd029675dae2f55995b7eeb4e9adc7a
2024-08-20 00:09:03 +08:00
hoshi-hiyouga
8c13c02c55 Update template.py
Former-commit-id: f5a075cb1c90f05bb0de26c6aea718f556c54623
2024-08-20 00:03:33 +08:00
hoshi-hiyouga
a101fde917 Merge pull request #5163 from liu-zichen/fix_ppo_optim
fix lr not change

Former-commit-id: f3c03ec6a89bf57f290820fa31eda24291355e4e
2024-08-19 23:56:24 +08:00
hoshi-hiyouga
1f4373b6e5 Merge pull request #5185 from chenhuiyu/feature/add-sailorllm-template
Add SailorLLM template

Former-commit-id: 28387d6b2f9e3bcc6321345c46b525c8180ebf7e
2024-08-19 23:51:49 +08:00
hoshi-hiyouga
525747b472 Merge pull request #5188 from Zxilly/main
fix: report correct device count for intel xpu
Former-commit-id: cd3c536cb3936061d905256850b0e57df4498010
2024-08-19 23:51:39 +08:00
hoshi-hiyouga
472f12c985 Merge pull request #5193 from Ricardo-L-C/main
_is_bf16_available judgment supports npu

Former-commit-id: 18b9ac49c45af773a2ea563f5e1852dc4b775db8
2024-08-19 23:40:59 +08:00
hoshi-hiyouga
b681f24f43 Update template.py
Former-commit-id: c6822a217e1c296f4aedd9a2c7610acd1dbd443e
2024-08-19 23:40:16 +08:00
hiyouga
fd02b089b6 update readme
Former-commit-id: 756e438866876fa54495cf557dd1e299b17a42fb
2024-08-19 23:32:04 +08:00
Ricardo
57d4c4a4f8 _is_bf16_available judgment supports npu
Former-commit-id: 50a1e892a1005b4cdd82dca1005f71db08ed89a2
2024-08-16 02:58:22 +00:00
Zxilly
3595d26846 fix: report correct device count for intel xpu
Former-commit-id: 0618f660b6511599365bd9be64499dbab41a79ba
2024-08-15 08:30:43 +00:00
Huiyu Chen
22a79c169d Add SailorLLM template
Former-commit-id: a594abe0321a718394a97b5a48ded16e2012c1f0
2024-08-15 15:10:14 +08:00
liu-zichen
75dfe259cf fix lr not change
Former-commit-id: 387dd2d51b5d8cd666459040fdd16525b34720d9
2024-08-13 16:33:34 +08:00
codingma
2e257d6af0 add tutorial and doc links
Former-commit-id: 4f6072562a34e0ec97471210ff54244cf0d0f3df
2024-08-13 16:13:10 +08:00
“Wzw”
e734222373 fix Llama-template's system prompt bug
Former-commit-id: 2e3eddcd0918b0c968ded0df7c82e3dcff870381
2024-08-12 19:22:12 +08:00
hiyouga
6a351b9912 update readme
Former-commit-id: 4fecc5ee56873a7ab4941e46a5168cfe2ecb4bb6
2024-08-10 10:17:35 +08:00
hiyouga
cfc04aa162 update readme
Former-commit-id: fa7bc9f1c7347153f9092ffbbb8e88c6b2f59632
2024-08-09 20:46:02 +08:00
hiyouga
943c795318 add magpie ultra dataset
Former-commit-id: 3317b24329b87e30f13a78936ac5554f211abf7a
2024-08-09 20:28:55 +08:00
hiyouga
7fb61bad04 add qwen2 math models
Former-commit-id: 72ff43a1772c9de5ff914d5e1c8bdc8dea9ae0c8
2024-08-09 20:20:35 +08:00
hiyouga
47efcdb1dd update examples
Former-commit-id: d5c57c8b7f64afe8061045ec9689abbac45c1175
2024-08-09 20:13:46 +08:00
hiyouga
59cbce1a46 add adam_mini to readme
Former-commit-id: d610c6bcf8a8ba6f4236f5d11f79571b83f4fb11
2024-08-09 20:02:03 +08:00
hoshi-hiyouga
7e755e9cac Merge pull request #5095 from relic-yuexi/feat-optimizer
Feat optimizer

Former-commit-id: f08390d252d42a812b71a08daba7339cc40889b7
2024-08-09 19:51:33 +08:00
hiyouga
9d1e2c3c1f update scripts
Former-commit-id: dabf5a1dc661a6581474c6a5ec115322d168ed5f
2024-08-09 19:16:23 +08:00
hiyouga
5af32ce705 follow #5115
Former-commit-id: 7d917e03e2df570139bae18227d9c7303a12de2a
2024-08-09 18:03:00 +08:00
hoshi-hiyouga
4e8861e653 Merge pull request #5115 from YeQiuO/main
fix: `Train on the last turn only` truncate bug
Former-commit-id: 2c6dae45f7a7b72c961489ac407b1b444ab7752e
2024-08-09 17:58:27 +08:00
hoshi-hiyouga
d4d7ffb17c Merge pull request #5072 from relic-yuexi/main
fix the deepseekcoder template to avoid repeat problem

Former-commit-id: 2ae7d5c91725eab9f994015d8d3577894c7978b6
2024-08-09 16:35:21 +08:00
hoshi-hiyouga
46f834ec75 Update template.py
Former-commit-id: ae2a5221c109ae3474d219c37433be767abbee91
2024-08-09 16:27:42 +08:00
“Wzw”
6ec64a7e56 mask_history args verify valid
Former-commit-id: 2f8388b4f4195d934400ad9267d72e10ca4105a3
2024-08-08 10:12:01 +08:00
“Wzw”
d71446e387 fix mask_history tiny bug
Former-commit-id: cac07aac6196be026f723b2397a343d4fb675973
2024-08-08 10:09:33 +08:00
codingma
eada49e56b fix eval_dataset in example
Former-commit-id: e1ffc54f7e58419cc8da958a4d3c2697e18d5583
2024-08-07 18:24:19 +08:00
moontidef
8f42d7df56 feat: add support for adammini
Former-commit-id: a2d5fafb705ff44db1711e972490f0abebc2012b
2024-08-07 10:08:22 +08:00
moontidef
33a90b9026 fix: rename optimzer to optimizer
Former-commit-id: 186dc1fde822e6a603ac273538741ea3853f243e
2024-08-07 10:05:01 +08:00
moontidef
710902b0d0 Merge branch 'hiyouga:main' into main
Former-commit-id: d1b23283e0e4286f126d38d7bdc55802f74c8922
2024-08-06 00:18:45 +08:00
moontidef
7b4f5d3b21 fix: fix the deepseekcoder template to avoid repeat problem
Former-commit-id: 56294831115f095135f72490a8a435434b2f0a11
2024-08-05 23:55:45 +08:00
hiyouga
13093963b1 fix #5048
Former-commit-id: 71a6861667ae68c1fd6a69acf68e1359b858cf1b
2024-08-05 23:48:19 +08:00
hoshi-hiyouga
2e477e7458 Merge pull request #5037 from codemayq/feature-gemma-2-2b
support gemma-2-2b

Former-commit-id: 6af51fadff92cd3e665c556ac073a1876f792ada
2024-08-05 23:27:37 +08:00
codingma
4b6252151e support gemma-2-2b
Former-commit-id: 7037192cf6049fd7d675aed4a6237ed929c6b170
2024-08-01 13:45:48 +08:00
hoshi-hiyouga
f3765d1996 Merge pull request #5010 from Eruly/main
Add Korean web UI (llamafactory-cli webui)

Former-commit-id: 2050806aa826028df45c0c746b4314afe178dcd3
2024-07-30 01:55:54 +08:00
hoshi-hiyouga
1f5cdd66b7 Merge pull request #4996 from LDLINGLINGLING/main
Added MiniCPM to the supported-model list on the homepage; MiniCPM's official GitHub also includes a friendly link to LLama_factory

Former-commit-id: a86a776fb0f75697b0fee7694a5a0d6bd04fee0a
2024-07-30 01:55:30 +08:00
hoshi-hiyouga
5b0ddbb835 Update README_zh.md
Former-commit-id: 922906faf2d432def7cfdac82f90472fa1bb24a9
2024-07-30 01:55:13 +08:00
hoshi-hiyouga
4f92b56f06 Update README.md
Former-commit-id: 6bc7f71940be0a8f1614f9036b9c539ce46d34e1
2024-07-30 01:53:19 +08:00
hoshi-hiyouga
a1f6ff92be Update README.md
Former-commit-id: 54eecdec0da06677ea55847c74642d0fc12d8908
2024-07-30 01:52:35 +08:00
hoshi-hiyouga
ef98e91618 Merge pull request #4995 from codemayq/fix-pissa
fix pissa callback

Former-commit-id: 052c0f6bd9e872ea325b5a6aef98c4c070733384
2024-07-30 01:47:25 +08:00
eruly
9fdf800750 Add Korean web UI (llamafactory-cli webui)
Former-commit-id: 357a035f2aeb9548368c230c5a17dcdfa4844b17
2024-07-29 13:47:13 +00:00
liudan
32c698e4c2 Added MiniCPM to the supported-model list on the homepage; MiniCPM's official GitHub also includes a friendly link to LLama_factory
Former-commit-id: f482a6e2fd30aff5113e53f3f07b4649982bcc2e
2024-07-29 10:58:28 +08:00
codingma
75e80fa820 fix pissa save
Former-commit-id: 25a1dad7c8df79c15efecb8c6f871a13a327f57a
2024-07-29 10:44:34 +08:00
hiyouga
f8329bc632 tiny fix
Former-commit-id: 183d8bd500a8e9513a077161ba8e8d61bea9200f
2024-07-26 11:51:00 +08:00
hoshi-hiyouga
9f74d36ba4 Merge pull request #4892 from piamo/main
update deepseek template

Former-commit-id: 3233efc8404972098665286d9dec7312dd6ecfab
2024-07-26 11:49:34 +08:00
hoshi-hiyouga
fc2435f135 Merge pull request #4950 from liuwwang/main
fix: Repair the issue where quantization failed after merging the adapter.
Former-commit-id: 93a68ea1f4372973f745a2c250250ecaac515e27
2024-07-26 11:48:56 +08:00
hoshi-hiyouga
0636519ba3 Merge pull request #4970 from HardAndHeavy/add-rocm
Add ROCm support

Former-commit-id: c0f21d869bce6e59825d57c66bce3fe54f50065f
2024-07-26 11:41:23 +08:00
hoshi-hiyouga
573bf03a6f Update README_zh.md
Former-commit-id: 86a27a97ff67b0d4bcd671c62759cd049542dc1b
2024-07-26 11:30:57 +08:00
hoshi-hiyouga
9e529be4e7 Update README.md
Former-commit-id: 1c167bb2ea3a47bdeeccc044a653662132c61698
2024-07-26 11:29:28 +08:00
hoshi-hiyouga
7af4ffa6cc Update README.md
Former-commit-id: d6e7a69c274c3756587e18a039637dd37fa152b2
2024-07-26 11:29:09 +08:00
HardAndHeavy
5b67ccd1c6 Add ROCm support
Former-commit-id: cf9df10a24936efd420b0fdac541fd6c0808a327
2024-07-25 21:29:28 +03:00
khazic
5166dbbcd3 Added the reference address for TRL PPO details.
Former-commit-id: 509c55608643eae3a6456683d425a7c636cfc3e9
2024-07-25 09:03:21 +08:00
hiyouga
21adb09730 fix #4959
Former-commit-id: 96e8a1d47874708c6157865c78be4cd6c533e01b
2024-07-24 23:44:00 +08:00
hiyouga
28b5f656db update webui
Former-commit-id: 463edec1b1c1345afc791e225deb33f118f3582e
2024-07-24 21:11:51 +08:00
hoshi-hiyouga
68ee2d512f Update README_zh.md
Former-commit-id: 1443e876697e18108573387e501a7453ba9fc06c
2024-07-24 21:08:42 +08:00
hoshi-hiyouga
a5f7e0efc6 Update README.md
Former-commit-id: 07d86e38cfd857d1dfa898541f3e5bd9c6f11581
2024-07-24 21:07:14 +08:00
hiyouga
211038584a tiny fix
Former-commit-id: 28cac0e325bfd7a6c0c344ad2d46511613190cd7
2024-07-24 18:33:39 +08:00
hiyouga
ff5ba97970 fix #4928
Former-commit-id: 6d557e8959678f9d4edbcb3d5a6dfba14b429b18
2024-07-24 17:00:29 +08:00
hiyouga
27f2c3cae1 fix #4925
Former-commit-id: 79c336e2339974471627487858d59e4ed2152370
2024-07-24 16:56:58 +08:00
hiyouga
48f0819327 fix #4944
Former-commit-id: 9e8cf3b21a0b12d1413c3c7f3d60399784909242
2024-07-24 16:42:51 +08:00
hiyouga
5c6d88e91c add mistral nemo model
Former-commit-id: 428bb49f53b32947bc0a62ca19ab10844154c07c
2024-07-24 16:25:53 +08:00
hiyouga
0a04d9470f add llama3.1
Former-commit-id: 3c433890f9b61c520572f5233aae70584da0f330
2024-07-24 16:20:11 +08:00
Liuww
f0408c0dde fix: Repair the issue where quantization failed after merging the adapter.
Former-commit-id: 8109561b7f577d448f8bca7e569f7f443cf6bb52
2024-07-24 14:31:29 +08:00
hiyouga
a041f4a111 tiny fix
Former-commit-id: bf6a2f032c598f969708c1c3db4875d6239c41a9
2024-07-22 21:10:15 +08:00
hoshi-hiyouga
cdf9dae53e fix #4917
Former-commit-id: e26919aafd8436489d065789c9c25d72c8d05a6d
2024-07-22 11:28:31 +08:00
hiyouga
1917f431f5 tiny fix
Former-commit-id: 9133316e558a3c8744f5eb6ab8678686bf4859ed
2024-07-22 00:06:03 +08:00
hiyouga
a770afbff2 fix flashattn + packing
Former-commit-id: 4adc6ce4abc718c25f39b316bfc3352d0d01ed1e
2024-07-21 17:07:45 +08:00
huangpan.foo
b1a5bf025b update deepseek template
Former-commit-id: f5ca86ec95bb301df42ffaa6923fc3037a224e34
2024-07-19 15:02:54 +08:00
hiyouga
adff3e5050 set dev version
Former-commit-id: 0b9a2275dc533b65578278f979ce053e95a644b3
2024-07-19 02:01:46 +08:00
127 changed files with 4083 additions and 1316 deletions

.env.local (new file, 35 added lines)

@@ -0,0 +1,35 @@
# Note: actually we do not support .env, just for reference
# api
API_HOST=0.0.0.0
API_PORT=8000
API_KEY=
API_MODEL_NAME=gpt-3.5-turbo
FASTAPI_ROOT_PATH=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
FORCE_TORCHRUN=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=
MASTER_PORT=
NNODES=
RANK=
NPROC_PER_NODE=
# wandb
WANDB_DISABLED=
WANDB_PROJECT=huggingface
WANDB_API_KEY=
# gradio ui
GRADIO_SHARE=False
GRADIO_SERVER_NAME=0.0.0.0
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
# setup
ENABLE_SHORT_CONSOLE=1
# reserved (do not use)
LLAMABOARD_ENABLED=
LLAMABOARD_WORKDIR=
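
The header comment says this file is reference-only, so none of these variables are loaded from `.env` automatically. A minimal sketch of how a few of them might be exported by hand before launching the CLI; the `llamafactory-cli api`/`webui` commands and the `examples/inference/llama3_vllm.yaml` path are taken from elsewhere in this changeset, while the port values are placeholders matching the Docker port mappings shown below:

```bash
# Hedged usage sketch: .env.local is documentation only, so export the
# variables manually, then start the OpenAI-style API server or the web UI.
export API_HOST=0.0.0.0
export API_PORT=8000                # OpenAI-style API port (matches the "-p 8000:8000" Docker mapping)
export GRADIO_SERVER_NAME=0.0.0.0
export GRADIO_SERVER_PORT=7860      # assumed value; matches the "-p 7860:7860" Docker mapping

llamafactory-cli api examples/inference/llama3_vllm.yaml   # API server, as in the README example
llamafactory-cli webui                                     # LLaMA Board web UI
```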

.github/workflows/tests.yml

@@ -3,14 +3,14 @@ name: tests
 on:
   push:
     branches:
-      - main
+      - "main"
     paths:
       - "**.py"
       - "requirements.txt"
       - ".github/workflows/*.yml"
   pull_request:
     branches:
-      - main
+      - "main"
     paths:
       - "**.py"
       - "requirements.txt"
@@ -18,13 +18,27 @@ on:
 jobs:
   tests:
-    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version:
+          - "3.8"
+          - "3.9"
+          - "3.10"
+          - "3.11"
+        os:
+          - "ubuntu-latest"
+          - "windows-latest"
+          - "macos-13"
+    runs-on: ${{ matrix.os }}
     environment:
       name: tests
     env:
       HF_TOKEN: ${{ secrets.HF_TOKEN }}
+      OS_NAME: ${{ matrix.os }}
     steps:
       - name: Checkout
@@ -33,13 +47,14 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.8"
+          python-version: ${{ matrix.python-version }}
           cache: "pip"
           cache-dependency-path: "setup.py"
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          python -m pip install git+https://github.com/huggingface/transformers.git
           python -m pip install ".[torch,dev]"
       - name: Check quality

.gitignore (2 added lines)

@@ -160,6 +160,8 @@ cython_debug/
 .idea/
 # custom .gitignore
+ms_cache/
+hf_cache/
 cache/
 config/
 saves/

Makefile

@@ -1,6 +1,6 @@
 .PHONY: quality style test
-check_dirs := scripts src tests
+check_dirs := scripts src tests setup.py
 quality:
 	ruff check $(check_dirs)
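
For reference, these targets are driven by `make`; a small usage sketch, under the assumption that `quality` is what the CI "Check quality" step invokes (the `style` and `test` targets are declared in `.PHONY` but their recipes are outside this hunk):

```bash
# After this change, the quality target also lints setup.py:
make quality        # expands to: ruff check scripts src tests setup.py
```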

README.md (184 changed lines)

@@ -4,7 +4,7 @@
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-72-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-91-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -21,13 +21,14 @@
 **Fine-tuning a large language model can be easy as...**
-https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
+https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
 Choose your path:
 - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
 - **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 - **Local machine**: Please refer to [usage](#getting-started)
+- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
 ## Table of Contents
@@ -46,11 +47,11 @@ Choose your path:
 ## Features
-- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
-- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
+- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -71,15 +72,23 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
<details><summary>Full Changelog</summary>
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
 [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
 [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
 [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
-<details><summary>Full Changelog</summary>
-[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
+[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion.
 [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
@@ -91,7 +100,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
-[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
 [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
@@ -103,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
-[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
+[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
@@ -151,32 +160,34 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Supported Models
 | Model | Model size | Template |
 | ------ | ------ | ------ |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
 | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
 | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
 | [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -200,6 +211,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 | ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+> [!TIP]
+> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
 ## Provided Datasets
 <details><summary>Pre-training datasets</summary>
@@ -259,7 +273,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
 - [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
 - [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
+- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
+- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
 - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -276,6 +292,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
+- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -296,20 +314,20 @@ huggingface-cli login
 | Mandatory | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | python | 3.8 | 3.11 |
-| torch | 1.13.1 | 2.3.0 |
-| transformers | 4.41.2 | 4.41.2 |
-| datasets | 2.16.0 | 2.19.2 |
-| accelerate | 0.30.1 | 0.30.1 |
-| peft | 0.11.1 | 0.11.1 |
-| trl | 0.8.6 | 0.9.4 |
+| torch | 1.13.1 | 2.4.0 |
+| transformers | 4.41.2 | 4.43.4 |
+| datasets | 2.16.0 | 2.20.0 |
+| accelerate | 0.30.1 | 0.32.0 |
+| peft | 0.11.1 | 0.12.0 |
+| trl | 0.8.6 | 0.9.6 |
 | Optional | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
 | deepspeed | 0.10.0 | 0.14.0 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
-| vllm | 0.4.3 | 0.4.3 |
-| flash-attn | 2.3.0 | 2.5.9 |
+| vllm | 0.4.3 | 0.5.0 |
+| flash-attn | 2.3.0 | 2.6.3 |
 ### Hardware Requirement
@@ -338,7 +356,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
-Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -422,16 +440,24 @@ For CUDA users:
 ```bash
 cd docker/docker-cuda/
-docker-compose up -d
-docker-compose exec llamafactory bash
+docker compose up -d
+docker compose exec llamafactory bash
 ```
 For Ascend NPU users:
 ```bash
 cd docker/docker-npu/
-docker-compose up -d
-docker-compose exec llamafactory bash
+docker compose up -d
+docker compose exec llamafactory bash
+```
+For AMD ROCm users:
+```bash
+cd docker/docker-rocm/
+docker compose up -d
+docker compose exec llamafactory bash
 ```
 <details><summary>Build without Docker Compose</summary>
@@ -493,13 +519,42 @@ docker run -dit \
 docker exec -it llamafactory bash
 ```
+For AMD ROCm users:
+```bash
+docker build -f ./docker/docker-rocm/Dockerfile \
+    --build-arg INSTALL_BNB=false \
+    --build-arg INSTALL_VLLM=false \
+    --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg INSTALL_FLASHATTN=false \
+    --build-arg PIP_INDEX=https://pypi.org/simple \
+    -t llamafactory:latest .
+docker run -dit \
+    -v ./hf_cache:/root/.cache/huggingface \
+    -v ./ms_cache:/root/.cache/modelscope \
+    -v ./data:/app/data \
+    -v ./output:/app/output \
+    -v ./saves:/app/saves \
+    -p 7860:7860 \
+    -p 8000:8000 \
+    --device /dev/kfd \
+    --device /dev/dri \
+    --shm-size 16G \
+    --name llamafactory \
+    llamafactory:latest
+docker exec -it llamafactory bash
+```
 </details>
 <details><summary>Details about volume</summary>
-- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
-- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
-- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
+- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
+- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
+- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
+- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
 </details>
@@ -510,7 +565,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```
 > [!TIP]
-> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
+> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
 ### Download from ModelScope Hub
@@ -600,7 +655,26 @@ If you have a project that should be incorporated, please contact via email or c
 1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
 1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
 1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
-1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburghs Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -618,7 +692,7 @@ If you have a project that should be incorporated, please contact via email or c
 This repository is licensed under the [Apache-2.0 License](LICENSE).
-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 ## Citation
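
Two of the additions above are configuration knobs rather than code: the install extras now include `liger-kernel` and `adam-mini`, and the changelog entries reference the `enable_liger_kernel: true` and `neat_packing: true` options. A minimal sketch of how they might be combined; the extra names and flags are taken verbatim from the diff, while treating the flags as keys in a training YAML is an assumption for illustration:

```bash
# Install with the newly listed optional extras (names verbatim from the README diff).
pip install -e ".[torch,metrics,liger-kernel,adam-mini]"

# The new switches are then set in a training YAML, per the changelog entries, e.g.:
#   enable_liger_kernel: true   # Liger Kernel acceleration
#   neat_packing: true          # contamination-free packed training
```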

README_zh.md

@@ -4,7 +4,7 @@
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-72-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-91-green)](#使用了-llama-factory-的项目)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -21,13 +21,15 @@
 **微调大模型可以像这样轻松…**
-https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd-d76c6d0a6594
+https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 选择你的打开方式:
 - **Colab**：https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
-- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**：https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 - **本地机器**：请见[如何使用](#如何使用)
+- **入门教程**：https://zhuanlan.zhihu.com/p/695287607
+- **框架文档**：https://llamafactory.readthedocs.io/zh-cn/latest/
 ## 目录
@@ -46,11 +48,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 项目特色
-- **多种模型**：LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
+- **多种模型**：LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 - **集成方法**：增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**：16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
-- **先进算法**：GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
+- **先进算法**：[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
-- **实用技巧**：FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
+- **实用技巧**：[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**：LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 - **极速推理**：基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -71,15 +73,23 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
<details><summary>展开日志</summary>
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `paligemma` 模板进行微调使其获得对话能力。
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
@@ -91,7 +101,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
@@ -103,7 +113,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
@@ -151,32 +161,34 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型

| 模型名                                                             | 模型大小                         | Template  |
| ------------------------------------------------------------------ | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc)                   | 7B/13B                           | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience)                   | 560M/1.1B/1.7B/3B/7.1B/176B      | -         |
| [ChatGLM3](https://huggingface.co/THUDM)                            | 6B                               | chatglm3  |
| [Command R](https://huggingface.co/CohereForAI)                     | 35B/104B                         | cohere    |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)           | 7B/16B/67B/236B                  | deepseek  |
| [Falcon](https://huggingface.co/tiiuae)                             | 7B/11B/40B/180B                  | falcon    |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)            | 2B/7B/9B/27B                     | gemma     |
| [GLM-4](https://huggingface.co/THUDM)                               | 9B                               | glm4      |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm)            | 7B/20B                           | intern2   |
| [Llama](https://github.com/facebookresearch/llama)                  | 7B/13B/33B/65B                   | -         |
| [Llama 2](https://huggingface.co/meta-llama)                        | 7B/13B/70B                       | llama2    |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)              | 8B/70B                           | llama3    |
| [LLaVA-1.5](https://huggingface.co/llava-hf)                        | 7B/13B                           | llava     |
| [MiniCPM](https://huggingface.co/openbmb)                           | 1B/2B/4B                         | cpm/cpm3  |
| [Mistral/Mixtral](https://huggingface.co/mistralai)                 | 7B/8x7B/8x22B                    | mistral   |
| [OLMo](https://huggingface.co/allenai)                              | 1B/7B                            | -         |
| [PaliGemma](https://huggingface.co/google)                          | 3B                               | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                   | 1.3B/2.7B                        | -         |
| [Phi-3](https://huggingface.co/microsoft)                           | 4B/7B/14B                        | phi       |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen)   | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen      |
| [Qwen2-VL](https://huggingface.co/Qwen)                             | 2B/7B                            | qwen2_vl  |
| [StarCoder 2](https://huggingface.co/bigcode)                       | 3B/7B/15B                        | -         |
| [XVERSE](https://huggingface.co/xverse)                             | 7B/13B/65B                       | xverse    |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai)                    | 1.5B/6B/9B/34B                   | yi        |
| [Yi-VL](https://huggingface.co/01-ai)                               | 6B/34B                           | yi_vl     |
| [Yuan 2](https://huggingface.co/IEITYuan)                           | 2B/51B/102B                      | yuan      |
> [!NOTE]
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
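下面是一个简单的参考示例(非官方文档内容,命令行参数形式仅作示意,模型路径可按需替换),展示在对话推理时为 Instruct 模型显式指定对应模板:

```bash
# 参考示例:为 Llama-3 Instruct 模型指定 llama3 模板进行命令行对话
# (假设已按下文步骤安装 llamafactory-cli;参数同样可以写入 yaml 配置文件)
llamafactory-cli chat --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3
```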
@@ -200,6 +212,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!TIP]
> 有关 PPO 的实现细节,请参考[此博客](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html)。
## 数据集
<details><summary>预训练数据集</summary>
@@ -259,7 +274,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -276,6 +293,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -296,20 +315,20 @@ huggingface-cli login
| 必需项       | 至少    | 推荐      |
| ------------ | ------- | --------- |
| python       | 3.8     | 3.11      |
| torch        | 1.13.1  | 2.4.0     |
| transformers | 4.41.2  | 4.43.4    |
| datasets     | 2.16.0  | 2.20.0    |
| accelerate   | 0.30.1  | 0.32.0    |
| peft         | 0.11.1  | 0.12.0    |
| trl          | 0.8.6   | 0.9.6     |

| 可选项       | 至少    | 推荐      |
| ------------ | ------- | --------- |
| CUDA         | 11.6    | 12.2      |
| deepspeed    | 0.10.0  | 0.14.0    |
| bitsandbytes | 0.39.0  | 0.43.1    |
| vllm         | 0.4.3   | 0.5.0     |
| flash-attn   | 2.3.0   | 2.6.3     |

### 硬件依赖
@@ -338,7 +357,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
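作为参考(以下组合仅作示意,请按实际需求增删额外依赖项):

```bash
# 示例:在安装时同时启用评测指标、DeepSpeed、Liger Kernel 和 BAdam 等可选依赖
pip install -e ".[torch,metrics,deepspeed,liger-kernel,badam]"
```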
> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -422,16 +441,24 @@ CUDA 用户:
```bash
cd docker/docker-cuda/
docker compose up -d
docker compose exec llamafactory bash
```
昇腾 NPU 用户:
```bash
cd docker/docker-npu/
docker compose up -d
docker compose exec llamafactory bash
```
AMD ROCm 用户:
```bash
cd docker/docker-rocm/
docker compose up -d
docker compose exec llamafactory bash
```
<details><summary>不使用 Docker Compose 构建</summary>
@@ -493,13 +520,42 @@ docker run -dit \
docker exec -it llamafactory bash
```
AMD ROCm 用户:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-v ./saves:/app/saves \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/kfd \
--device /dev/dri \
--shm-size 16G \
--name llamafactory \
llamafactory:latest
docker exec -it llamafactory bash
```
</details>
<details><summary>数据卷详情</summary>
- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
- `data`:宿主机中存放数据集的文件夹路径。
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
</details>
@@ -510,7 +566,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
### 从魔搭社区下载
@@ -600,7 +656,26 @@ run_name: test_run # 可选
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
@@ -618,7 +693,7 @@ run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用


@@ -23,6 +23,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
"system": "the column name in the dataset containing the system prompts. (default: None)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)", "tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)", "images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)", "chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)", "rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)" "kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
@@ -107,7 +108,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
### Preference Dataset
Preference datasets are used for reward modeling, DPO training, ORPO training and SimPO training.
They require a better response in the `chosen` column and a worse response in the `rejected` column.
@@ -139,67 +140,15 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
### KTO Dataset
An additional column `kto_tag` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
### Multimodal Image Dataset
An additional column `images` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
### Multimodal Video Dataset
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
## Sharegpt Format
@@ -252,6 +201,10 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### Pre-training Dataset
Not yet supported, please use the [alpaca](#alpaca-format) format.
### Preference Dataset
- [Example dataset](dpo_en_demo.json)
@@ -302,6 +255,125 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### KTO Dataset
- [Example dataset](kto_en_demo.json)
KTO datasets require an extra `kto_tag` column containing the boolean human feedback.
```json
[
{
"conversations": [
{
"from": "human",
"value": "human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"kto_tag": "human feedback [true/false] (required)"
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"kto_tag": "kto_tag"
}
}
```
### Multimodal Image Dataset
- [Example dataset](mllm_demo.json)
Multimodal image datasets require an `images` column containing the paths to the input images.
The number of images should be identical to the number of `<image>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<image>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"images": [
"image path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"images": "images"
}
}
```
### Multimodal Video Dataset
- [Example dataset](mllm_video_demo.json)
Multimodal video datasets require a `videos` column containing the paths to the input videos.
The number of videos should be identical to the number of `<video>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<video>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"videos": [
"video path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"videos": "videos"
}
}
```
### OpenAI Format
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
@@ -345,7 +417,3 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
}
```


@@ -23,6 +23,7 @@
"system": "数据集代表系统提示的表头名称默认None", "system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None", "tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None", "images": "数据集代表图像输入的表头名称默认None",
"videos": "数据集代表视频输入的表头名称默认None",
"chosen": "数据集代表更优回答的表头名称默认None", "chosen": "数据集代表更优回答的表头名称默认None",
"rejected": "数据集代表更差回答的表头名称默认None", "rejected": "数据集代表更差回答的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None" "kto_tag": "数据集代表 KTO 标签的表头名称默认None"
@@ -107,7 +108,7 @@
### 偏好数据集
偏好数据集用于奖励模型训练、DPO 训练、ORPO 训练和 SimPO 训练。
它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
@@ -139,67 +140,15 @@
### KTO 数据集
KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
### 多模态图像数据集
多模态图像数据集需要提供额外的 `images` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
### 多模态视频数据集
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
## Sharegpt 格式
@@ -252,6 +201,10 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
```
### 预训练数据集
尚不支持,请使用 [alpaca](#alpaca-格式) 格式。
### 偏好数据集
- [样例数据集](dpo_zh_demo.json)
@@ -302,6 +255,125 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
}
```
### KTO 数据集
- [样例数据集](kto_en_demo.json)
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
```json
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"kto_tag": "人类反馈 [true/false](必填)"
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"kto_tag": "kto_tag"
}
}
```
### 多模态图像数据集
- [样例数据集](mllm_demo.json)
多模态图像数据集需要额外添加一个 `images` 列,包含输入图像的路径。
注意图片的数量必须与文本中所有 `<image>` 标记的数量严格一致。
```json
[
{
"conversations": [
{
"from": "human",
"value": "<image>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"images": [
"图像路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"images": "images"
}
}
```
### 多模态视频数据集
- [样例数据集](mllm_video_demo.json)
多模态视频数据集需要额外添加一个 `videos` 列,包含输入视频的路径。
注意视频的数量必须与文本中所有 `<video>` 标记的数量严格一致。
```json
[
{
"conversations": [
{
"from": "human",
"value": "<video>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"videos": [
"视频路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"videos": "videos"
}
}
```
### OpenAI 格式
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
@@ -345,7 +417,3 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
}
}
```

data/mllm_demo_data/1.mp4 (new binary file, not shown)
data/mllm_demo_data/2.avi (new binary file, not shown)
data/mllm_demo_data/3.mp4 (new binary file, not shown)


@@ -1,9 +1,9 @@
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
FROM ascendai/cann:8.0.rc1-910b-ubuntu22.04-py3.8
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8
# Define environments
ENV DEBIAN_FRONTEND=noninteractive


@@ -0,0 +1,57 @@
FROM hardandheavy/transformers-rocm:2.1.0
# Define environments
ENV MAX_JOBS=4
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
# Define installation arguments
ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG PIP_INDEX=https://pypi.org/simple
# Set the working directory
WORKDIR /app
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
# Copy the rest of the application into the image
COPY . /app
# Install the LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_BNB" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
fi; \
if [ "$INSTALL_VLLM" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
fi; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi
# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
# Expose port 7860 for the LLaMA Board
ENV GRADIO_SERVER_PORT 7860
EXPOSE 7860
# Expose port 8000 for the API service
ENV API_PORT 8000
EXPOSE 8000


@@ -0,0 +1,29 @@
services:
llamafactory:
build:
dockerfile: ./docker/docker-rocm/Dockerfile
context: ../..
args:
INSTALL_BNB: false
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
INSTALL_FLASHATTN: false
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../data:/app/data
- ../../output:/app/output
- ../../saves:/app/saves
ports:
- "7860:7860"
- "8000:8000"
ipc: host
tty: true
stdin_open: true
command: bash
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
restart: unless-stopped
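As a usage sketch (mirroring the CUDA/NPU instructions above), this ROCm compose file is expected to be launched from its own directory:

```bash
# start the ROCm container defined above and open a shell inside it
cd docker/docker-rocm/
docker compose up -d
docker compose exec llamafactory bash
```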


@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### Multimodal DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```
#### Reward Modeling
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### KTO Training
```bash
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using Adam-mini
```bash
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```
#### LoRA+ Fine-Tuning
```bash


@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### 多模态 DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```
#### 奖励模型训练
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### KTO 训练
```bash
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 多模态指令监督微调
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### 使用 Adam-mini 进行全参数训练
```bash
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```
#### LoRA+ 微调
```bash


@@ -1,27 +1,22 @@
### model
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
### method
stage: sft
do_train: true
finetuning_type: full
use_adam_mini: true
### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2-1_5b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -30,10 +25,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
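For reference, this config is launched like the other example configs; the path below assumes the file is saved as examples/extras/adam_mini/qwen2_full_sft.yaml, as referenced in the examples README above:

```bash
# full-parameter SFT of Qwen2-1.5B-Instruct with the Adam-mini optimizer
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```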


@@ -10,6 +10,7 @@ badam_mode: layer
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
# deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -29,7 +30,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1


@@ -29,11 +29,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1


@@ -2,5 +2,5 @@
python scripts/llama_pro.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir models/llama3-8b-pro \
--num_expand 8


@@ -1,5 +1,5 @@
### model
model_name_or_path: models/llama3-8b-pro
### method
stage: sft stage: sft
@@ -18,7 +18,7 @@ overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b-pro/freeze/sft
logging_steps: 10
save_steps: 500
plot_loss: true


@@ -26,7 +26,7 @@ overwrite_output_dir: true
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
optim: paged_adamw_8bit
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1


@@ -0,0 +1,5 @@
#!/bin/bash
python scripts/pissa_init.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir models/llama3-8b-pissa


@@ -1,3 +1,2 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava


@@ -0,0 +1,2 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
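A possible usage sketch for the new Qwen2-VL inference config (the file path examples/inference/qwen2_vl.yaml is an assumption inferred from the naming of the other inference configs):

```bash
# chat with Qwen2-VL-7B-Instruct using the two-line config above
llamafactory-cli chat examples/inference/qwen2_vl.yaml
```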


@@ -0,0 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
template: qwen2_vl
finetuning_type: lora
### export
export_dir: models/qwen2_vl_lora_sft
export_size: 2
export_device: cpu
export_legacy_format: false
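A usage sketch for this merge config (the path examples/merge_lora/qwen2vl_lora_sft.yaml is assumed from the naming convention of the other merge configs):

```bash
# merge the Qwen2-VL LoRA adapter into the base model and export it on CPU
llamafactory-cli export examples/merge_lora/qwen2vl_lora_sft.yaml
```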


@@ -7,7 +7,7 @@ do_predict: true
finetuning_type: full
### dataset
eval_dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 50


@@ -25,7 +25,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1


@@ -0,0 +1,39 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo,identity
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
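As noted in the examples README above, this full-parameter multimodal SFT config is launched with torchrun via the CLI:

```bash
# multimodal supervised fine-tuning of Qwen2-VL-7B-Instruct with DeepSpeed ZeRO-3
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```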


@@ -1,6 +1,5 @@
### model
model_name_or_path: llava-hf/llava-1.5-7b-hf
### method
stage: sft
@@ -10,7 +9,7 @@ lora_target: all
### dataset
dataset: mllm_demo
template: llava
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true


@@ -0,0 +1,41 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
pref_beta: 0.1
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: rlhf_v
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
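As referenced in the examples README above, this config drives multimodal DPO/ORPO/SimPO training (the objective is selected via `pref_loss`):

```bash
# multimodal preference training with LoRA (DPO by default, see pref_loss)
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```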


@@ -0,0 +1,39 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
### dataset
dataset: mllm_demo,identity # video: mllm_video_demo
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
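As referenced in the examples README above, this LoRA SFT config is launched with:

```bash
# multimodal LoRA supervised fine-tuning of Qwen2-VL-7B-Instruct
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```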


@@ -1,8 +1,8 @@
transformers>=4.41.2,<=4.45.0
datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.33.0
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
gradio>=4.0.0
pandas>=2.0.0
scipy


@@ -27,7 +27,7 @@ from llamafactory.chat import ChatModel
def calculate_flops(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 512,
flash_attn: str = "auto",
):
r"""
@@ -36,9 +36,11 @@ def calculate_flops(
""" """
with get_accelerator().device(0): with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()} input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True) flops, macs, params = get_model_profile(
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
)
print("FLOPs:", flops) print("FLOPs:", flops)
print("MACs:", macs) print("MACs:", macs)
print("Params:", params) print("Params:", params)


@@ -25,7 +25,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
@@ -39,16 +39,17 @@ def calculate_lr(
model_name_or_path: str,
batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,  # i.e. maximum input length during training
is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@@ -66,7 +67,8 @@ def calculate_lr(
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
@@ -84,7 +86,7 @@ def calculate_lr(
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
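The updated docstring above gives the intended invocation; a concrete sketch (model path is only an example):

```bash
# estimate an appropriate learning rate for a total batch size of 16 on alpaca_en_demo
python scripts/cal_lr.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
```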

scripts/cal_mfu.py (new file)

@@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig
from llamafactory.train.tuner import run_exp
BASE = 2 # gemm (add + mul)
def compute_model_flops(
model_name_or_path: str,
total_batch_size: int,
seq_length: int,
include_backward: bool = True,
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
r"""
Calculates the FLOPs of model per forward/backward pass.
"""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
intermediate_size = getattr(config, "intermediate_size", None)
num_attention_heads = getattr(config, "num_attention_heads", None)
num_key_value_heads = getattr(config, "num_key_value_heads", None)
num_hidden_layers = getattr(config, "num_hidden_layers", None)
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
# mlp module
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
# attn projector module
q_flops_per_token = BASE * hidden_size * hidden_size
o_flops_per_token = BASE * hidden_size * hidden_size
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
# attn sdpa module
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
# embedding module
embedding_flops_per_token = hidden_size * vocab_size
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
if tie_word_embeddings is False:
embedding_flops *= 2
non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
non_embedding_coeff, embedding_coeff = 1, 1
if include_backward:
non_embedding_coeff += 2
embedding_coeff += 2
if include_recompute:
non_embedding_coeff += 1
total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
if include_flashattn:
total_flops += sdpa_flops
return total_flops
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
elif "A100" in device_name or "A800" in device_name:
return 312 * 1e12 * world_size
elif "V100" in device_name:
return 125 * 1e12 * world_size
elif "4090" in device_name:
return 98 * 1e12 * world_size
else:
raise NotImplementedError("Device not supported: {}.".format(device_name))
def calculate_mfu(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 1024,
num_steps: int = 100,
finetuning_type: str = "lora",
flash_attn: str = "auto",
deepspeed_stage: int = 0,
disable_gc: bool = False,
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
"model_name_or_path": model_name_or_path,
"flash_attn": flash_attn,
"disable_gradient_checkpointing": disable_gc,
"enable_liger_kernel": liger_kernel,
"use_unsloth_gc": unsloth_gc,
"stage": "pt",
"do_train": True,
"finetuning_type": finetuning_type,
"dataset": "c4_demo",
"cutoff_len": seq_length,
"output_dir": os.path.join("saves", "test_mfu"),
"logging_strategy": "no",
"save_strategy": "no",
"save_only_model": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": num_steps,
"bf16": True,
}
if deepspeed_stage in [2, 3]:
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print("MFU: {:.2f}%".format(mfu_value * 100))
if __name__ == "__main__":
fire.Fire(calculate_mfu)
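To make the MFU arithmetic concrete, a hedged sketch with made-up figures; none of these values come from the script, only the formula does.

```python
# Suppose a run on 8 A100-class GPUs reports 0.45 optimizer steps per second, and
# compute_model_flops(...) returns ~4.4e15 FLOPs per step for the global batch.
train_steps_per_second = 0.45
model_flops_per_step = 4.4e15
device_flops = 8 * 312e12            # what compute_device_flops(8) returns for A100/A800

mfu = train_steps_per_second * model_flops_per_step / device_flops
print("MFU: {:.2f}%".format(mfu * 100))  # 79.33% with these illustrative numbers
```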

cal_ppl.py

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
@@ -55,12 +55,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
return super().__call__(chosen_features)
-def cal_ppl(
+def calculate_ppl(
model_name_or_path: str,
save_name: str,
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
-dataset: str = "alpaca_en",
+dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
@@ -69,7 +69,7 @@ def cal_ppl(
):
r"""
Calculates the ppl on the dataset of the pre-trained models.
-Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
@@ -88,7 +88,8 @@ def cal_ppl(
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
-trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
+template = get_template_and_fix_tokenizer(tokenizer, data_args)
+trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -129,4 +130,4 @@ def cal_ppl(
if __name__ == "__main__":
-fire.Fire(cal_ppl)
+fire.Fire(calculate_ppl)

length_cdf.py

@@ -18,21 +18,21 @@ from collections import defaultdict
import fire
from tqdm import tqdm
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
-dataset: str = "alpaca_en",
+dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
interval: int = 1000,
):
r"""
Calculates the distribution of the input lengths in the dataset.
-Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@@ -48,7 +48,8 @@ def length_cdf(
)
)
tokenizer_module = load_tokenizer(model_args)
-trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
+template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):


@@ -19,7 +19,7 @@
import json
import os
from collections import OrderedDict
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
import fire
import torch
@@ -47,8 +47,8 @@ def block_expansion(
model_name_or_path: str,
output_dir: str,
num_expand: int,
-shard_size: Optional[str] = "2GB",
-save_safetensors: Optional[bool] = False,
+shard_size: str = "2GB",
+save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.


@@ -16,7 +16,7 @@
import json
import os
from collections import OrderedDict
-from typing import Any, Dict, Optional
+from typing import Any, Dict
import fire
import torch
@@ -86,7 +86,10 @@ def save_config(input_dir: str, output_dir: str):
def llamafy_baichuan2(
-input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+input_dir: str,
+output_dir: str,
+shard_size: str = "2GB",
+save_safetensors: bool = True,
):
r"""
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.


@@ -16,7 +16,7 @@
import json
import os
from collections import OrderedDict
-from typing import Any, Dict, Optional
+from typing import Any, Dict
import fire
import torch
@@ -139,7 +139,10 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
def llamafy_qwen(
-input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+input_dir: str,
+output_dir: str,
+shard_size: str = "2GB",
+save_safetensors: bool = False,
):
r"""
Converts the Qwen models in the same format as LLaMA2.


@@ -67,7 +67,7 @@ def quantize_loftq(
loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model
-setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
+setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir))


@@ -31,7 +31,7 @@ if TYPE_CHECKING:
def quantize_pissa(
model_name_or_path: str,
output_dir: str,
-pissa_iter: int = 4,
+pissa_iter: int = 16,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
@@ -62,6 +62,7 @@ def quantize_pissa(
pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model
+setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir))

setup.py

@@ -14,11 +14,12 @@
import os
import re
+from typing import List
from setuptools import find_packages, setup
-def get_version():
+def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
@@ -26,27 +27,37 @@ def get_version():
return version
-def get_requires():
+def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
+def get_console_scripts() -> List[str]:
+console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
+if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
+console_scripts.append("lmf = llamafactory.cli:main")
+return console_scripts
extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
-"deepspeed": ["deepspeed>=0.10.0"],
+"deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
+"liger-kernel": ["liger-kernel"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
-"vllm": ["vllm>=0.4.3"],
+"vllm": ["vllm>=0.4.3,<=0.6.0"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
+"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"dev": ["ruff", "pytest"],
@@ -70,7 +81,7 @@ def main():
python_requires=">=3.8.0",
install_requires=get_requires(),
extras_require=extra_require,
-entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
+entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",


@@ -20,22 +20,27 @@ Level:
Dependency graph:
main:
-transformers>=4.41.2
+transformers>=4.41.2,<=4.45.0
-datasets>=2.16.0
+datasets>=2.16.0,<=2.21.0
-accelerate>=0.30.1
+accelerate>=0.30.1,<=0.33.0
-peft>=0.11.1
+peft>=0.11.1,<=0.12.0
-trl>=0.8.6
+trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
-transformers>=4.41.2,<=4.42.4
+transformers>=4.41.2,<=4.45.0
packing:
-transformers>=4.41.2,<=4.42.4
+transformers>=4.41.2,<=4.45.0
-patcher:
-transformers==4.41.2 (chatglm)
+Disable version checking: DISABLE_VERSION_CHECK=1
+Enable VRAM recording: RECORD_VRAM=1
+Force check imports: FORCE_CHECK_IMPORTS=1
+Force using torchrun: FORCE_TORCHRUN=1
+Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
+Use modelscope: USE_MODELSCOPE_HUB=1
"""
-from .cli import VERSION
+from .extras.env import VERSION
__version__ = VERSION
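The docstring now doubles as a reference for runtime switches. A hedged illustration of how they are typically set before importing the package, assuming it is installed; the switches themselves are implemented elsewhere in the code base.

```python
import os

os.environ["LLAMAFACTORY_VERBOSITY"] = "WARN"  # quieter logging
os.environ["DISABLE_VERSION_CHECK"] = "1"      # skip dependency version checks

import llamafactory
print(llamafactory.__version__)  # resolved from llamafactory.extras.env.VERSION after this change
```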


@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
import os
from contextlib import asynccontextmanager
+from functools import partial
from typing import Optional
from typing_extensions import Annotated
@@ -50,14 +52,24 @@ if is_uvicorn_available():
import uvicorn
+async def sweeper() -> None:
+while True:
+torch_gc()
+await asyncio.sleep(300)
@asynccontextmanager
-async def lifespan(app: "FastAPI"):  # collects GPU memory
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
+if chat_model.engine_type == "huggingface":
+asyncio.create_task(sweeper())
yield
torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
-app = FastAPI(lifespan=lifespan)
+root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -65,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
-api_key = os.environ.get("API_KEY")
+api_key = os.environ.get("API_KEY", None)
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -79,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)],
)
async def list_models():
-model_card = ModelCard(id="gpt-3.5-turbo")
+model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(
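The new sweeper() task periodically frees cached GPU memory while the HuggingFace engine serves requests. Below is a minimal, self-contained sketch of the same create_task-plus-sleep pattern; the interval and the printed stand-in for torch_gc() are placeholders.

```python
import asyncio

async def sweeper(interval: float) -> None:
    while True:
        print("torch_gc() would run here")  # stand-in for the real cleanup call
        await asyncio.sleep(interval)

async def main() -> None:
    task = asyncio.create_task(sweeper(0.1))  # fire-and-forget background task, as in lifespan()
    await asyncio.sleep(0.25)                 # the API would keep handling requests here
    task.cancel()

asyncio.run(main())
```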


@@ -16,6 +16,7 @@ import base64
import io
import json
import os
+import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
@@ -51,9 +52,8 @@ if is_requests_available():
if TYPE_CHECKING:
-from numpy.typing import NDArray
from ..chat import ChatModel
+from ..data.mm_plugin import ImageInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -69,7 +69,7 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0:
@@ -104,15 +104,14 @@ def _process_request(
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
-if image_url.startswith("data:image"):  # base64 image
-image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
-image_path = io.BytesIO(image_data)
+if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
+image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url):  # local file
-image_path = open(image_url, "rb")
+image_stream = open(image_url, "rb")
else:  # web uri
-image_path = requests.get(image_url, stream=True).raw
+image_stream = requests.get(image_url, stream=True).raw
-image = Image.open(image_path).convert("RGB")
+image = Image.open(image_stream).convert("RGB")
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
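A hedged check of the data-URI pattern introduced above; the example strings are made up.

```python
import re

pattern = r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$"
print(bool(re.match(pattern, "data:image/png;base64,iVBORw0KGgo=")))  # True  -> decoded as base64
print(bool(re.match(pattern, "https://example.com/cat.png")))         # False -> fetched as a web uri
```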


@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
-from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
+from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -35,6 +35,12 @@ class Response:
class BaseEngine(ABC):
+r"""
+Base class for inference engine of chat models.
+Must implements async methods: chat(), stream_chat() and get_scores().
+"""
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
-) -> None: ...
+) -> None:
+r"""
+Initializes an inference engine.
+"""
+...
@abstractmethod
async def chat(
@@ -56,9 +66,14 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
-) -> List["Response"]: ...
+) -> List["Response"]:
+r"""
+Gets a list of responses of the chat model.
+"""
+...
@abstractmethod
async def stream_chat(
@@ -66,13 +81,22 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
-) -> AsyncGenerator[str, None]: ...
+) -> AsyncGenerator[str, None]:
+r"""
+Gets the response token-by-token of the chat model.
+"""
+...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
-) -> List[float]: ...
+) -> List[float]:
+r"""
+Gets a list of scores of the reward model.
+"""
+...


@@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
-from numpy.typing import NDArray
+from ..data.mm_plugin import ImageInput, VideoInput
from .base_engine import BaseEngine, Response
@@ -38,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
+r"""
+General class for chat models. Backed by huggingface or vllm engines.
+Supports both sync and async methods.
+Sync methods: chat(), stream_chat() and get_scores().
+Async methods: achat(), astream_chat() and aget_scores().
+"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
@@ -56,10 +64,16 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
-task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+r"""
+Gets a list of responses of the chat model.
+"""
+task = asyncio.run_coroutine_threadsafe(
+self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+)
return task.result()
async def achat(
@@ -67,20 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
-return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+r"""
+Asynchronously gets a list of responses of the chat model.
+"""
+return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
-generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+r"""
+Gets the response token-by-token of the chat model.
+"""
+generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -93,10 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
-async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+r"""
+Asynchronously gets the response token-by-token of the chat model.
+"""
+async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token
def get_scores(
@@ -104,6 +130,9 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
+r"""
+Gets a list of scores of the reward model.
+"""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
@@ -112,6 +141,9 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
+r"""
+Asynchronously gets a list of scores of the reward model.
+"""
return await self.engine.get_scores(batch_input, **input_kwargs)
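A hedged usage sketch of the sync wrappers documented above; the model path and template value are placeholders.

```python
from llamafactory.chat import ChatModel

# Placeholder arguments; any valid model path and template name would do.
chat_model = ChatModel({"model_name_or_path": "path_to_model", "template": "default"})
messages = [{"role": "user", "content": "Hello"}]

for new_text in chat_model.stream_chat(messages):   # sync wrapper around astream_chat()
    print(new_text, end="", flush=True)
```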


@@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
import torch
from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
@@ -29,12 +31,11 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
-from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
-from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper
from ..data import Template
+from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -54,7 +55,7 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
-self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)  # must after fixing tokenizer to resize vocab
@@ -78,31 +79,30 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
-if (
-processor is not None
-and image is not None
-and not hasattr(processor, "image_seq_length")
-and template.image_token not in messages[0]["content"]
-):  # llava-like models
-messages[0]["content"] = template.image_token + messages[0]["content"]
+mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+if image is not None:
+mm_input_dict.update({"images": [image], "imglens": [1]})
+if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+if video is not None:
+mm_input_dict.update({"videos": [video], "vidlens": [1]})
+if VIDEO_PLACEHOLDER not in messages[0]["content"]:
+messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+messages = template.mm_plugin.process_messages(
+messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
-pixel_values = None
-prompt_ids, _ = template.encode_oneturn(
-tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
-)
-if processor is not None and image is not None:  # add image features
-image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-batch_feature = image_processor(image, return_tensors="pt")
-pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
-if hasattr(processor, "image_seq_length"):  # paligemma models
-image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+prompt_ids, _ = template.mm_plugin.process_token_ids(
+prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
@@ -164,8 +164,10 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
-if pixel_values is not None:
-gen_kwargs["pixel_values"] = pixel_values
+mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+for key, value in mm_inputs.items():
+value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length
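The multimodal path now normalizes the first user message before encoding. A hedged illustration of the placeholder handling; the literal placeholder value below is an assumption, the real one comes from extras.constants.

```python
IMAGE_PLACEHOLDER = "<image>"  # assumed value for illustration

messages = [{"role": "user", "content": "What is in this picture?"}]
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
    messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]

print(messages[0]["content"])  # "<image>What is in this picture?"
```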
@@ -180,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -215,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
-model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -267,12 +271,14 @@ class HuggingfaceEngine(BaseEngine):
return scores
+@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -289,18 +295,21 @@ class HuggingfaceEngine(BaseEngine):
system,
tools,
image,
+video,
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
+@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -317,6 +326,7 @@ class HuggingfaceEngine(BaseEngine):
system,
tools,
image,
+video,
input_kwargs,
)
async with self.semaphore:
@@ -328,6 +338,7 @@ class HuggingfaceEngine(BaseEngine):
except StopAsyncIteration:
break
+@override
async def get_scores(
self,
batch_input: List[str],


@@ -15,32 +15,31 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
+if is_pillow_available():
+from PIL import Image
+from PIL.Image import Image as ImageObject
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
-if is_vllm_version_greater_than_0_5_1():
-pass
-elif is_vllm_version_greater_than_0_5():
-from vllm.multimodal.image import ImagePixelData
-else:
-from vllm.sequence import MultiModalData
if TYPE_CHECKING:
-from numpy.typing import NDArray
-from transformers.image_processing_utils import BaseImageProcessor
+from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
-self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.generating_args = generating_args.to_dict()
engine_args = {
@@ -85,19 +84,11 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank,
}
-if model_args.visual_inputs:
-image_size = config.vision_config.image_size
-patch_size = config.vision_config.patch_size
-self.image_feature_size = (image_size // patch_size) ** 2
-engine_args["image_input_type"] = "pixel_values"
-engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
-engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
-engine_args["image_feature_size"] = self.image_feature_size
-if getattr(config, "is_yi_vl_derived_model", None):
-import vllm.model_executor.models.llava
+if getattr(config, "is_yi_vl_derived_model", None):
+import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
@@ -110,37 +101,18 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-if (
-self.processor is not None
-and image is not None
-and not hasattr(self.processor, "image_seq_length")
-and self.template.image_token not in messages[0]["content"]
-):  # llava-like models (TODO: paligemma models)
-messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
+if image is not None:
+if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
-prompt_ids, _ = self.template.encode_oneturn(
-tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
-)
-if self.processor is not None and image is not None:  # add image features
-image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-if is_vllm_version_greater_than_0_5_1():
-multi_modal_data = {"image": pixel_values}
-elif is_vllm_version_greater_than_0_5():
-multi_modal_data = ImagePixelData(image=pixel_values)
-else:  # TODO: remove vllm 0.4.3 support
-multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-else:
-multi_modal_data = None
+prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
@@ -185,6 +157,17 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True,
)
+if image is not None:  # add image features
+if not isinstance(image, (str, ImageObject)):
+raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
+if isinstance(image, str):
+image = Image.open(image).convert("RGB")
+multi_modal_data = {"image": image}
+else:
+multi_modal_data = None
result_generator = self.model.generate(
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
@@ -193,16 +176,18 @@ class VllmEngine(BaseEngine):
)
return result_generator
+@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
-generator = await self._generate(messages, system, tools, image, **input_kwargs)
+generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for request_output in generator:
final_output = request_output
@@ -219,21 +204,24 @@ class VllmEngine(BaseEngine):
return results
+@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+image: Optional["ImageInput"] = None,
+video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
-generator = await self._generate(messages, system, tools, image, **input_kwargs)
+generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
+@override
async def get_scores(
self,
batch_input: List[str],


@@ -118,4 +118,4 @@ def main():
elif command == Command.HELP:
print(USAGE)
else:
-raise NotImplementedError("Unknown command: {}".format(command))
+raise NotImplementedError("Unknown command: {}.".format(command))


@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
+from .collator import (
+KTODataCollatorWithPadding,
+MultiModalDataCollatorForSeq2Seq,
+PairwiseDataCollatorWithPadding,
+SFTDataCollatorWith4DAttentionMask,
+)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -20,6 +25,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
+"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",


@@ -14,9 +14,7 @@
import os
from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
-from datasets import Features
from ..extras.logging import get_logger
from .data_utils import Role
@@ -27,88 +25,117 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
+from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr
logger = get_logger(__name__)
-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(
+images: Sequence["ImageInput"],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
-outputs = []
-if dataset_attr.load_from in ["script", "file"]:
-for image in images:
-if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-outputs.append(os.path.join(data_args.dataset_dir, image))
-else:
-outputs.append(image)
-return outputs
+if len(images) == 0:
+return None
+images = images[:]
+if dataset_attr.load_from in ["script", "file"]:
+for i in range(len(images)):
+if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
+images[i] = os.path.join(data_args.dataset_dir, images[i])
+return images
+def _convert_videos(
+videos: Sequence["VideoInput"],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Optional[List["VideoInput"]]:
+r"""
+Optionally concatenates video path to dataset dir when loading from local disk.
+"""
+if len(videos) == 0:
+return None
+videos = videos[:]
+if dataset_attr.load_from in ["script", "file"]:
+for i in range(len(videos)):
+if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
+videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+return videos
def convert_alpaca(
-examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
-) -> Dict[str, List[Any]]:
+example: Dict[str, Any],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Dict[str, Any]:
r"""
Converts alpaca format dataset to the standard format.
"""
-outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
-convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
-for i in range(len(examples[dataset_attr.prompt])):
-prompt = []
-if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
-for old_prompt, old_response in examples[dataset_attr.history][i]:
-prompt.append({"role": Role.USER.value, "content": old_prompt})
-prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
-content = []
-if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
-content.append(examples[dataset_attr.prompt][i])
-if dataset_attr.query and examples[dataset_attr.query][i]:
-content.append(examples[dataset_attr.query][i])
-prompt.append({"role": Role.USER.value, "content": "\n".join(content)})  # "prompt\nquery"
-if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
-response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
-if examples[dataset_attr.kto_tag][i]:
-response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
-else:
-response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
-elif (
-dataset_attr.ranking
-and isinstance(examples[dataset_attr.chosen][i], str)
-and isinstance(examples[dataset_attr.rejected][i], str)
-):  # pairwise example
-response = [
-{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
-{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
-]
-elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):  # normal example
-response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
-else:  # unsupervised
-response = []
-outputs["prompt"].append(prompt)
-outputs["response"].append(response)
-outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-return outputs
+prompt = []
+if dataset_attr.history and isinstance(example[dataset_attr.history], list):
+for old_prompt, old_response in example[dataset_attr.history]:
+prompt.append({"role": Role.USER.value, "content": old_prompt})
+prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
+query = []
+if dataset_attr.prompt and example[dataset_attr.prompt]:
+query.append(example[dataset_attr.prompt])
+if dataset_attr.query and example[dataset_attr.query]:
+query.append(example[dataset_attr.query])
+prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"
+if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
+response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
+if example[dataset_attr.kto_tag]:
+response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+else:
+response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+elif (
+dataset_attr.ranking
+and isinstance(example[dataset_attr.chosen], str)
+and isinstance(example[dataset_attr.rejected], str)
+):  # pairwise example
+response = [
+{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
+{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
+]
+elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
+response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
+else:  # unsupervised
+response = []
+convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+output = {
+"_prompt": prompt,
+"_response": response,
+"_system": example[dataset_attr.system] if dataset_attr.system else "",
+"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+}
+return output
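To make the per-example alpaca conversion concrete, a hedged illustration with a made-up record; the actual column mapping comes from DatasetAttr and is not reproduced here.

```python
# Hypothetical alpaca-format record (field names are illustrative):
example = {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}

# Roughly what the converter produces for it ("prompt\nquery" joining, single assistant response):
prompt = [{"role": "user", "content": "Translate to French.\nGood morning"}]
response = [{"role": "assistant", "content": "Bonjour"}]
```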
-def convert_sharegpt(
-    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
-) -> Dict[str, List[Any]]:
+def convert_sharegpt(
+    example: Dict[str, Any],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
-    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
-    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -119,74 +146,79 @@ def convert_sharegpt(
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
-    for i, messages in enumerate(examples[dataset_attr.messages]):
-        if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
-            system = messages[0][dataset_attr.content_tag]
-            messages = messages[1:]
-        else:
-            system = examples[dataset_attr.system][i] if dataset_attr.system else ""
-
-        if len(messages) == 0:
-            continue
-
-        aligned_messages = []
-        broken_data = False
-        for turn_idx, message in enumerate(messages):
-            if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-                logger.warning("Invalid role tag in {}.".format(messages))
-                broken_data = True
-
-            aligned_messages.append(
-                {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
-            )
-
-        if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
-            dataset_attr.ranking and len(aligned_messages) % 2 == 0
-        ):
-            logger.warning("Invalid message count in {}.".format(messages))
-            broken_data = True
-
-        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
-            prompt = aligned_messages[:-1]
-            response = aligned_messages[-1:]
-            if examples[dataset_attr.kto_tag][i]:
-                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
-            else:
-                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
-        elif (
-            dataset_attr.ranking
-            and isinstance(examples[dataset_attr.chosen][i], dict)
-            and isinstance(examples[dataset_attr.rejected][i], dict)
-        ):  # pairwise example
-            chosen = examples[dataset_attr.chosen][i]
-            rejected = examples[dataset_attr.rejected][i]
-            if (
-                chosen[dataset_attr.role_tag] not in accept_tags[-1]
-                or rejected[dataset_attr.role_tag] not in accept_tags[-1]
-            ):
-                logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
-                broken_data = True
-
-            prompt = aligned_messages
-            response = [
-                {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
-                {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
-            ]
-        else:  # normal example
-            prompt = aligned_messages[:-1]
-            response = aligned_messages[-1:]
-
-        if broken_data:
-            logger.warning("Skipping this abnormal example.")
-            continue
-
-        outputs["prompt"].append(prompt)
-        outputs["response"].append(response)
-        outputs["system"].append(system)
-        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-
-    return outputs
+    messages = example[dataset_attr.messages]
+    if (
+        dataset_attr.system_tag
+        and len(messages) != 0
+        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
+    ):
+        system = messages[0][dataset_attr.content_tag]
+        messages = messages[1:]
+    else:
+        system = example[dataset_attr.system] if dataset_attr.system else ""
+
+    aligned_messages = []
+    broken_data = False
+    for turn_idx, message in enumerate(messages):
+        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
+            logger.warning("Invalid role tag in {}.".format(messages))
+            broken_data = True
+
+        aligned_messages.append(
+            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
+        )
+
+    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+        dataset_attr.ranking and len(aligned_messages) % 2 == 0
+    ):
+        logger.warning("Invalid message count in {}.".format(messages))
+        broken_data = True
+
+    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
+        prompt = aligned_messages[:-1]
+        response = aligned_messages[-1:]
+        if example[dataset_attr.kto_tag]:
+            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+        else:
+            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+    elif (
+        dataset_attr.ranking
+        and isinstance(example[dataset_attr.chosen], dict)
+        and isinstance(example[dataset_attr.rejected], dict)
+    ):  # pairwise example
+        chosen = example[dataset_attr.chosen]
+        rejected = example[dataset_attr.rejected]
+        if (
+            chosen[dataset_attr.role_tag] not in accept_tags[-1]
+            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
+        ):
+            logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+            broken_data = True
+
+        prompt = aligned_messages
+        response = [
+            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
+            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
+        ]
+    else:  # normal example
+        prompt = aligned_messages[:-1]
+        response = aligned_messages[-1:]
+
+    if broken_data:
+        logger.warning("Skipping this abnormal example.")
+        prompt, response = [], []
+
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+    output = {
+        "_prompt": prompt,
+        "_response": response,
+        "_system": system,
+        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+    }
+    return output
def align_dataset(
@@ -197,11 +229,12 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
-        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
-        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
-        system: "..."
-        tools: "...",
-        images: [],
+        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+        _system: "..."
+        _tools: "...",
+        _images: [],
+        _videos: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
@@ -209,19 +242,6 @@ def align_dataset(
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
-    features = Features.from_dict(
-        {
-            "prompt": [
-                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-            ],
-            "response": [
-                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-            ],
-            "system": {"dtype": "string", "_type": "Value"},
-            "tools": {"dtype": "string", "_type": "Value"},
-            "images": [{"_type": "Image"}],
-        }
-    )
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
@@ -232,8 +252,7 @@ def align_dataset(
    return dataset.map(
        convert_func,
-        batched=True,
+        batched=False,
        remove_columns=column_names,
-        features=features,
        **kwargs,
    )
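For reference, a minimal sketch of what one aligned example looks like after this change; the field names come from the code above, while the concrete values are made up for illustration:

# Hypothetical aligned example (illustrative values only).
aligned_example = {
    "_prompt": [{"role": "user", "content": "What is the capital of France?"}],
    "_response": [{"role": "assistant", "content": "Paris."}],
    "_system": "",
    "_tools": "",
    "_images": None,  # or a list of images when dataset_attr.images is set
    "_videos": None,  # or a list of videos when dataset_attr.videos is set
}

Because `align_dataset` now maps with `batched=False` and no explicit `Features` schema, each raw example is converted individually and the `datasets` library infers the resulting features.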

View File

@@ -16,12 +16,18 @@
# limitations under the License.

from dataclasses import dataclass
-from typing import Any, Dict, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
from transformers import DataCollatorForSeq2Seq


+if TYPE_CHECKING:
+    from transformers import ProcessorMixin
+
+    from .template import Template
+
+
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@@ -62,7 +68,42 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
-class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    r"""
+    Data collator that supports VLMs.
+
+    Features should contain input_ids, attention_mask, labels and images.
+    """
+
+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        for feature in features:
+            images = feature.pop("images", None) or []
+            videos = feature.pop("videos", None) or []
+            batch_images.extend(images)
+            batch_videos.extend(videos)
+            batch_imglens.append(len(images))
+            batch_vidlens.append(len(videos))
+            batch_seqlens.append(len(feature["input_ids"]))
+
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+        )
+        if "token_type_ids" in mm_inputs:
+            token_type_ids = mm_inputs.pop("token_type_ids")
+            for i, feature in enumerate(features):
+                feature["token_type_ids"] = token_type_ids[i]
+
+        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features.update(mm_inputs)
+        return features
+
+
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for 4d attention mask.
    """
@@ -80,7 +121,7 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass
-class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """
@@ -99,20 +140,16 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
                    "input_ids": feature["{}_input_ids".format(key)],
                    "attention_mask": feature["{}_attention_mask".format(key)],
                    "labels": feature["{}_labels".format(key)],
+                    "images": feature["images"],
+                    "videos": feature["videos"],
                }
-                if "pixel_values" in feature:
-                    target_feature["pixel_values"] = feature["pixel_values"]
-
-                if "{}_token_type_ids".format(key) in feature:
-                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)


@dataclass
-class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """
@@ -126,19 +163,16 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
+                "images": feature["images"],
+                "videos": feature["videos"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
+                "images": feature["images"],
+                "videos": feature["videos"],
            }
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])
@@ -148,7 +182,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
-        if "token_type_ids" in batch:
+        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)

View File

@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Merges multiple datasets to a unified dataset.
+    """
    if len(all_datasets) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
@@ -67,14 +70,16 @@ def merge_dataset(
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )
    else:
-        raise ValueError("Unknown mixing strategy.")
+        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))


def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
    r"""
-    Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
+    Splits the dataset and returns a dataset dict containing train set and validation set.
+
+    Supports both map dataset and iterable dataset.
    """
    if data_args.streaming:
        dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
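A minimal sketch of how these two helpers chain together; the dataset objects and argument containers are assumed to exist elsewhere:

# Illustrative only; mirrors the branches shown above.
merged = merge_dataset([dataset_a, dataset_b], data_args, seed=training_args.seed)
dataset_dict = split_dataset(merged, data_args, seed=training_args.seed)
train_split = dataset_dict["train"]            # train split
eval_split = dataset_dict.get("validation")    # validation split, when one is configured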

View File

@@ -16,21 +16,36 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+from typing_extensions import override

from .data_utils import SLOTS
-from .tool_utils import DefaultToolUtils, GLM4ToolUtils
+from .tool_utils import get_tool_utils
+
+
+if TYPE_CHECKING:
+    from .tool_utils import FunctionCall


@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default", "glm4"]] = None
+    tool_format: Optional[str] = None

    @abstractmethod
-    def apply(self, **kwargs) -> SLOTS: ...
+    def apply(self, **kwargs) -> SLOTS:
+        r"""
+        Forms a list of slots according to the inputs to encode.
+        """
+        ...

-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extract a list of tuples from the response message if using tools.
+
+        Each tuple consists of function name and function arguments.
+        """
        raise NotImplementedError
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
        if has_placeholder:
            raise ValueError("Empty formatter should not contain any placeholder.")

+    @override
    def apply(self, **kwargs) -> SLOTS:
        return self.slots
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
        if not has_placeholder:
            raise ValueError("A placeholder is required in the string formatter.")

+    @override
    def apply(self, **kwargs) -> SLOTS:
        elements = []
        for slot in self.slots:
@@ -81,13 +98,9 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
    def __post_init__(self):
-        if self.tool_format == "default":
-            self.slots = DefaultToolUtils.get_function_slots() + self.slots
-        elif self.tool_format == "glm4":
-            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
-        else:
-            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots

+    @override
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        functions: List[Tuple[str, str]] = []
@@ -119,22 +132,17 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
-        if self.tool_format == "default":
-            self._tool_formatter = DefaultToolUtils.tool_formatter
-            self._tool_extractor = DefaultToolUtils.tool_extractor
-        elif self.tool_format == "glm4":
-            self._tool_formatter = GLM4ToolUtils.tool_formatter
-            self._tool_extractor = GLM4ToolUtils.tool_extractor
-        else:
-            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+        self.tool_utils = get_tool_utils(self.tool_format)

+    @override
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
-            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
            return [""]

-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
-        return self._tool_extractor(content)
+    @override
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+        return self.tool_utils.tool_extractor(content)
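With tool handling routed through `get_tool_utils`, a formatter can be built from a format name alone. A hedged sketch, assuming "default" remains a registered tool format as it was before this refactor; the exact tool schema accepted by those utils lives in tool_utils.py, outside this excerpt:

# Sketch only.
formatter = ToolFormatter(tool_format="default")
slots = formatter.apply(content='[{"name": "get_weather", "description": "Query weather", "parameters": {}}]')
result = formatter.extract("...model response text...")  # a plain string, or a list of FunctionCall entries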

View File

@@ -27,7 +27,6 @@ from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer


if TYPE_CHECKING:
@@ -49,6 +48,9 @@ def _load_single_dataset(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Loads a single dataset and aligns it to the standard format.
+    """
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -118,7 +120,7 @@ def _load_single_dataset(
    if dataset_attr.num_samples is not None and not data_args.streaming:
        target_num = dataset_attr.num_samples
-        indexes = np.random.permutation(len(dataset))[:target_num]
+        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
        target_num -= len(indexes)
        if target_num > 0:
            expand_indexes = np.random.choice(len(dataset), target_num)
@@ -142,6 +144,9 @@ def _get_merged_dataset(
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Gets the merged datasets in the standard format.
+    """
    if dataset_names is None:
        return None
@@ -165,6 +170,9 @@ def _get_preprocessed_dataset(
    processor: Optional["ProcessorMixin"] = None,
    is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Preprocesses the dataset, including format checking and tokenization.
+    """
    if dataset is None:
        return None
@@ -180,7 +188,13 @@ def _get_preprocessed_dataset(
            desc="Running tokenizer on dataset",
        )

-    dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
+    dataset = dataset.map(
+        preprocess_func,
+        batched=True,
+        batch_size=data_args.preprocessing_batch_size,
+        remove_columns=column_names,
+        **kwargs,
+    )

    if training_args.should_log:
        try:
@@ -196,6 +210,7 @@ def _get_preprocessed_dataset(
def get_dataset(
+    template: "Template",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
@@ -203,10 +218,9 @@ def get_dataset(
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
+    r"""
+    Gets the train dataset and optionally gets the evaluation dataset.
+    """
    # Load tokenized dataset
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
@@ -217,6 +231,7 @@ def get_dataset(
            dataset_module: Dict[str, "Dataset"] = {}
            if "train" in dataset_dict:
                dataset_module["train_dataset"] = dataset_dict["train"]
            if "validation" in dataset_dict:
                dataset_module["eval_dataset"] = dataset_dict["validation"]
@@ -270,6 +285,7 @@ def get_dataset(
    dataset_module = {}
    if "train" in dataset_dict:
        dataset_module["train_dataset"] = dataset_dict["train"]
    if "validation" in dataset_dict:
        dataset_module["eval_dataset"] = dataset_dict["validation"]

View File

@@ -0,0 +1,400 @@
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_pyav_available():
import av
if TYPE_CHECKING:
import torch
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
ImageInput = Union[str, EncodedImage, ImageObject]
VideoInput = str
def _regularize_images(
images: Sequence["ImageInput"],
processor: "ProcessorMixin",
max_resolution: Optional[int] = None,
) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading, resizing and converting.
"""
if max_resolution is None:
max_resolution: int = getattr(processor, "image_resolution", 512)
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
if max(image.width, image.height) > max_resolution:
factor = max_resolution / max(image.width, image.height)
image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
results.append(image)
return results
def _regularize_videos(
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
video_resolution: int = getattr(processor, "video_resolution", 128)
video_fps: float = getattr(processor, "video_fps", 1.0)
video_maxlen: int = getattr(processor, "video_maxlen", 64)
video_factor: int = getattr(processor, "video_factor", 1)
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = min(video_maxlen, sample_frames) # reduce length <= maxlen
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = _regularize_images(frames, processor, video_resolution)
results.append(frames)
return results
def _get_mm_inputs(
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
input_dict = {"images": None} # default key
if len(images) != 0:
images = _regularize_images(images, processor)
input_dict["images"] = images
if len(videos) != 0:
videos = _regularize_videos(videos, processor)
input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
return image_processor(**input_dict, return_tensors="pt")
else:
return {}
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
def _validate_input(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
) -> None:
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(images, videos)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(images, videos)
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
"""
self._validate_input(images, videos)
return {}
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = _get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
class Qwen2vlPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = _get_mm_inputs(images, videos, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
)
num_video_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
PLUGINS = {
"base": BasePlugin,
"llava": LlavaPlugin,
"paligemma": PaliGemmaPlugin,
"qwen2_vl": Qwen2vlPlugin,
}
def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))
return plugin_class(image_token, video_token)
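A short usage sketch for the plugin registry above; the processor object and the image token string are illustrative rather than taken from this diff, and IMAGE_PLACEHOLDER is the constant imported at the top of this file:

# Illustrative wiring only.
plugin = get_mm_plugin(name="llava", image_token="<image>")
messages = plugin.process_messages(
    messages=[{"role": "user", "content": IMAGE_PLACEHOLDER + "Describe this picture."}],
    images=["example.jpg"],
    videos=[],
    processor=processor,  # assumed to expose image_processor / image_seqlen
)
mm_inputs = plugin.get_mm_inputs(
    images=["example.jpg"], videos=[], imglens=[1], vidlens=[0], seqlens=[128], processor=processor,
)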

View File

@@ -43,6 +43,7 @@ class DatasetAttr:
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
+    videos: Optional[str] = None
    # rlhf columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
@@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
            dataset_attr.set_attr("num_samples", dataset_info[name])

        if "columns" in dataset_info[name]:
-            column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
+            column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
            if dataset_attr.formatting == "alpaca":
                column_names.extend(["prompt", "query", "response", "history"])
            else:

View File

@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
-            if data_args.neat_packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
+                processor=processor,
                data_args=data_args,
            )
        else:

View File

@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template
@@ -35,14 +37,13 @@ def _encode_feedback_example(
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
@@ -55,6 +56,8 @@ def _encode_feedback_example(
    else:
        kl_messages = prompt + [kl_response[1]]

+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
+    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@@ -62,15 +65,13 @@ def _encode_feedback_example(
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
-        kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
-
-    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
+    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
+    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    response_ids = response_ids[:target_len]
-    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
+    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
    kl_prompt_ids = kl_prompt_ids[:kl_source_len]
    kl_response_ids = kl_response_ids[:kl_target_len]
@@ -78,7 +79,6 @@ def _encode_feedback_example(
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
-
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag
@@ -88,39 +88,27 @@ def preprocess_feedback_dataset(
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
-    kl_response = examples["response"][::-1]
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-        "kl_input_ids": [],
-        "kl_attention_mask": [],
-        "kl_labels": [],
-        "kto_tags": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-            model_inputs["kl_token_type_ids"] = []
-
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    kl_response = examples["_response"][::-1]
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
            continue

        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
            kl_response=kl_response[i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -129,11 +117,8 @@ def preprocess_feedback_dataset(
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
-                model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
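To make the KL pairing above concrete, a toy illustration with invented values:

# Each example keeps its own prompt but, for the KL term, is paired with the
# response of the mirrored example in the batch.
responses = [["r0", "d0"], ["r1", "d1"], ["r2", "d2"]]   # (target, dummy) pairs per example
kl_responses = responses[::-1]                           # example 0 now sees ["r2", "d2"], and so on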

View File

@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template
@@ -34,16 +36,15 @@ def _encode_pairwise_example(
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
-    chosen_messages = prompt + [response[0]]
-    rejected_messages = prompt + [response[1]]
+    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
+    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@@ -51,13 +52,9 @@ def _encode_pairwise_example(
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
-
-    source_len, target_len = infer_seqlen(
-        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
-    )  # consider the response is more important
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
+    # consider the response is more important
+    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]
@@ -66,7 +63,6 @@ def _encode_pairwise_example(
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
@@ -76,36 +72,25 @@ def preprocess_pairwise_dataset(
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-    model_inputs = {
-        "chosen_input_ids": [],
-        "chosen_attention_mask": [],
-        "chosen_labels": [],
-        "rejected_input_ids": [],
-        "rejected_attention_mask": [],
-        "rejected_labels": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["chosen_token_type_ids"] = []
-            model_inputs["rejected_token_type_ids"] = []
-
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
            continue

        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
@@ -113,15 +98,8 @@ def preprocess_pairwise_dataset(
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["chosen_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(chosen_input_ids), processor)
-                )
-                model_inputs["rejected_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(rejected_input_ids), processor)
-                )
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs
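A worked toy example of the label construction done by `_encode_pairwise_example` (token ids are arbitrary; IGNORE_INDEX is the constant imported above, conventionally -100):

prompt_ids = [11, 12, 13]                                        # already truncated to source_len
chosen_ids = [21, 22]                                            # already truncated to target_len
chosen_input_ids = prompt_ids + chosen_ids                       # [11, 12, 13, 21, 22]
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids    # prompt tokens are masked out of the loss
# rejected_input_ids / rejected_labels are built the same way from rejected_ids.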

View File

@@ -27,16 +27,16 @@ if TYPE_CHECKING:
def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
-    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

    if not data_args.packing:
        if data_args.template == "gemma":
            text_examples = [tokenizer.bos_token + example for example in text_examples]

-        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
+        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
    else:
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
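For context, the packing branch that follows this excerpt groups the concatenated token ids into fixed-size blocks. A simplified sketch of that grouping; the exact implementation is not shown in this diff:

block_size = data_args.cutoff_len
total_length = (len(concatenated_examples["input_ids"]) // block_size) * block_size
result = {
    key: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
    for key, tokens in concatenated_examples.items()
}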

View File

@@ -13,20 +13,7 @@
# limitations under the License.

import bisect
-from typing import TYPE_CHECKING, List, Sequence, Tuple
-
-from ...extras.packages import is_pillow_available
-
-
-if is_pillow_available():
-    from PIL import Image
-
-
-if TYPE_CHECKING:
-    from numpy.typing import NDArray
-    from PIL.Image import Image as ImageObject
-    from transformers import ProcessorMixin
-    from transformers.image_processing_utils import BaseImageProcessor
+from typing import List, Sequence, Tuple


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
@@ -61,23 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    return knapsacks


-def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
-    r"""
-    Processes visual inputs. (currently only supports a single image)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
-    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
-
-
-def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
-    r"""
-    Gets paligemma token type ids for computing loss.
-    """
-    image_seq_length = getattr(processor, "image_seq_length")
-    return [0] * image_seq_length + [1] * (input_len - image_seq_length)
-
-
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    r"""
    Computes the real sequence length after truncation by the cutoff_len.

View File

@@ -17,13 +17,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
@@ -35,47 +36,49 @@ def _encode_supervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
+    train_on_prompt: bool,
+    mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
-    messages = prompt + response
-    input_ids, labels = [], []
-
-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
-        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
-
+    messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
+    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
-    total_length = 1 if template.efficient_eos else 0
+    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
+    if mask_history:
+        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
+
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= data_args.cutoff_len:
+        if total_length >= cutoff_len:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
         source_ids = source_ids[:source_len]
         target_ids = target_ids[:target_len]
         total_length += source_len + target_len

-        if data_args.train_on_prompt:
+        if train_on_prompt:
             source_label = source_ids
-        elif turn_idx != 0 and template.efficient_eos:
+        elif template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len

-        if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
+        if mask_history and turn_idx != 0:  # train on the last turn only
             target_label = [IGNORE_INDEX] * target_len
         else:
             target_label = target_ids

-        input_ids += source_ids + target_ids
-        labels += source_label + target_label
+        if mask_history:  # reversed sequences
+            input_ids = source_ids + target_ids + input_ids
+            labels = source_label + target_label + labels
+        else:
+            input_ids += source_ids + target_ids
+            labels += source_label + target_label

     if template.efficient_eos:
         input_ids += [tokenizer.eos_token_id]
@@ -90,37 +93,34 @@ def preprocess_supervised_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = defaultdict(list)
if processor is not None: for i in range(len(examples["_prompt"])):
model_inputs["pixel_values"] = [] if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
if hasattr(processor, "image_seq_length"): # paligemma models logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
data_args=data_args, cutoff_len=data_args.cutoff_len,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
) )
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
if processor is not None: model_inputs["images"].append(examples["_images"][i])
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs["videos"].append(examples["_videos"][i])
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs return model_inputs
@@ -129,28 +129,34 @@ def preprocess_packed_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
-    # TODO: use `position_ids` to achieve packing
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
-    batch_input_ids, batch_labels = [], []
+    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
             continue

         input_ids, labels = _encode_supervised_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
-            processor=None,
-            data_args=data_args,
+            processor=processor,
+            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            train_on_prompt=data_args.train_on_prompt,
+            mask_history=data_args.mask_history,
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
@@ -160,16 +166,21 @@ def preprocess_packed_supervised_dataset(
             length2indexes[length].append(valid_num)
             batch_input_ids.append(input_ids)
             batch_labels.append(labels)
+            batch_images.append(examples["_images"][i] or [])
+            batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1

-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
+    model_inputs = defaultdict(list)
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
+        packed_images, packed_videos = [], []
         for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
+            packed_images += batch_images[index]
+            packed_videos += batch_videos[index]
             if data_args.neat_packing:
                 packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
             else:
@@ -190,6 +201,8 @@ def preprocess_packed_supervised_dataset(
model_inputs["input_ids"].append(packed_input_ids) model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels) model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
return model_inputs return model_inputs
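The neat_packing branch above gives every packed sequence its own segment id instead of a flat mask of ones, which is what lets attention stay block-diagonal inside a packed window. A small illustrative sketch (standalone and simplified, not the project's code):

from typing import List


def build_packed_attention_mask(sequence_lengths: List[int], neat_packing: bool) -> List[int]:
    # With neat_packing, each packed sequence gets its own segment id (1, 2, 3, ...),
    # so attention can be restricted to within each segment; otherwise the mask is all ones.
    mask: List[int] = []
    for i, length in enumerate(sequence_lengths):
        mask += ([i + 1] if neat_packing else [1]) * length
    return mask


print(build_packed_attention_mask([3, 2], neat_packing=True))   # [1, 1, 1, 2, 2]
print(build_packed_attention_mask([3, 2], neat_packing=False))  # [1, 1, 1, 1, 1]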

View File

@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen


 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
@@ -34,28 +36,25 @@ def _encode_unsupervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     if len(response) == 1:
         messages = prompt + response
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
-
-    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
     return input_ids, labels
@@ -67,36 +66,31 @@ def preprocess_unsupervised_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>` # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = defaultdict(list)
if processor is not None: for i in range(len(examples["_prompt"])):
model_inputs["pixel_values"] = [] if len(examples["_prompt"][i]) % 2 != 1:
if hasattr(processor, "image_seq_length"): # paligemma models logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
input_ids, labels = _encode_unsupervised_example( input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
data_args=data_args, cutoff_len=data_args.cutoff_len,
) )
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
if processor is not None: model_inputs["images"].append(examples["_images"][i])
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs["videos"].append(examples["_videos"][i])
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs return model_inputs

View File

@@ -15,15 +15,21 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

+from transformers.utils.versions import require_version
+from typing_extensions import override
+
 from ..extras.logging import get_logger
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .mm_plugin import get_mm_plugin


 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

+    from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
+    from .mm_plugin import BasePlugin


 logger = get_logger(__name__)
@@ -41,9 +47,9 @@ class Template:
format_prefix: "Formatter" format_prefix: "Formatter"
default_system: str default_system: str
stop_words: List[str] stop_words: List[str]
image_token: str
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
mm_plugin: "BasePlugin"
def encode_oneturn( def encode_oneturn(
self, self,
@@ -147,6 +153,7 @@ class Template:
 @dataclass
 class Llama2Template(Template):
+    @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -190,7 +197,7 @@ class Llama2Template(Template):
     return encoded_messages


-TEMPLATES: Dict[str, Template] = {}
+TEMPLATES: Dict[str, "Template"] = {}


 def _register_template(
@@ -205,9 +212,9 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: Sequence[str] = [],
-    image_token: str = "<image>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
     Registers a chat template.
@@ -254,9 +261,9 @@ def _register_template(
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
-        image_token=image_token,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        mm_plugin=mm_plugin,
     )
@@ -300,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
 def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+    r"""
+    Returns the jinja template.
+    """
     jinja_template = ""

     prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -310,14 +320,15 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += ( jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}" "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
) )
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if not isinstance(template, Llama2Template): if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}" jinja_template += "{% for message in loop_messages %}"
jinja_template += "{% set content = message['content'] %}" jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template): if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
@@ -338,23 +349,30 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template


-def get_template_and_fix_tokenizer(
-    tokenizer: "PreTrainedTokenizer",
-    name: Optional[str] = None,
-    tool_format: Optional[str] = None,
-) -> Template:
-    if name is None:
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    r"""
+    Gets chat template and fixes the tokenizer.
+    """
+    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
+        require_version(
+            "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
+        )
+
+    if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(name, None)
+        template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(name))
+            raise ValueError("Template {} does not exist.".format(data_args.template))

-    if tool_format is not None:
-        logger.info("Using tool format: {}.".format(tool_format))
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    if data_args.tool_format is not None:
+        logger.info("Using tool format: {}.".format(data_args.tool_format))
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_tools = ToolFormatter(tool_format=tool_format)
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

     stop_words = template.stop_words
     if template.replace_eos:
@@ -549,6 +567,15 @@ _register_template(
 )


+_register_template(
+    name="cpm3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|im_end|>"],
+)
+
+
 _register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -578,6 +605,7 @@ _register_template(
 _register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )
@@ -585,14 +613,14 @@ _register_template(
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
-        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
-        "developed by Deepseek Company, and you only answer questions related to computer science. "
+        "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
+        "developed by DeepSeek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
-        "and other non-computer science questions, you will refuse to answer\n"
+        "and other non-computer science questions, you will refuse to answer.\n"
     ),
 )
@@ -714,6 +742,17 @@ _register_template(
 )


+_register_template(
+    name="llava",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
+)
+
+
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -758,6 +797,19 @@ _register_template(
 )


+_register_template(
+    name="paligemma",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+)
+
+
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -781,6 +833,33 @@ _register_template(
 )


+_register_template(
+    name="qwen2_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
+_register_template(
+    name="sailor",
+    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system=(
+        "You are an AI assistant named Sailor created by Sea AI Lab. "
+        "Your answer should be friendly, unbiased, faithful, informative and detailed."
+    ),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+)
+
+
 _register_template(
     name="solar",
     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@@ -878,6 +957,7 @@ _register_template(
     ),
     stop_words=["###"],
     efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
 )
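A usage sketch of the updated entry point: the template name now travels inside DataArguments rather than being passed as a bare string. The import paths, model id, and argument values below are assumptions for illustration; only the new signature itself comes from the diff above.

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer  # src/llamafactory layout assumed
from llamafactory.hparams import DataArguments

# The template name (and tool format) now live on DataArguments.
data_args = DataArguments(template="qwen", cutoff_len=2048)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

template = get_template_and_fix_tokenizer(tokenizer, data_args)
prompt_ids, answer_ids = template.encode_oneturn(
    tokenizer,
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    None,  # system
    None,  # tools
)
print(len(prompt_ids), len(answer_ids))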

View File

@@ -15,9 +15,12 @@
 import json
 import re
 from abc import ABC, abstractmethod
+from collections import namedtuple
 from dataclasses import dataclass
 from typing import Any, Dict, List, Tuple, Union

+from typing_extensions import override
+
 from .data_utils import SLOTS
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
 )


+FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+
+
 @dataclass
 class ToolUtils(ABC):
+    """
+    Base class for tool utilities.
+    """
+
     @staticmethod
     @abstractmethod
-    def get_function_slots() -> SLOTS: ...
+    def get_function_slots() -> SLOTS:
+        r"""
+        Gets a list of slots corresponding to a single function call.
+        """
+        ...

     @staticmethod
     @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        r"""
+        Generates the system message describing all the available tools.
+        """
+        ...

     @staticmethod
     @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extracts all the function calls from the response message.
+        """
+        ...


 class DefaultToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["Action: {{name}}\nAction Input: {{arguments}}\n"]

+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
         return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))

+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
         action_match: List[Tuple[str, str]] = re.findall(regex, content)
         if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
 class GLM4ToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["{{name}}\n{{arguments}}"]

+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
         return GLM4_TOOL_PROMPT.format(tool_text=tool_text)

+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         if "\n" not in content:
             return content
@@ -138,3 +166,17 @@ class GLM4ToolUtils(ToolUtils):
             return content

         return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
+
+
+TOOLS = {
+    "default": DefaultToolUtils(),
+    "glm4": GLM4ToolUtils(),
+}
+
+
+def get_tool_utils(name: str) -> "ToolUtils":
+    tool_utils = TOOLS.get(name, None)
+    if tool_utils is None:
+        raise ValueError("Tool utils `{}` not found.".format(name))
+
+    return tool_utils
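The default tool extractor added above parses "Action: ... / Action Input: ..." pairs with a single regular expression. A hedged, standalone sketch that reuses the same pattern on a made-up response string:

import re

# The same pattern used by the default tool extractor in the diff above.
TOOL_CALL_REGEX = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)

response = 'Action: get_weather\nAction Input: {"city": "Berlin"}\nAction: get_time\nAction Input: {"tz": "UTC"}'
print(TOOL_CALL_REGEX.findall(response))
# [('get_weather', '{"city": "Berlin"}'), ('get_time', '{"tz": "UTC"}')]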

View File

@@ -39,7 +39,7 @@
 import json
 import os
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import numpy as np
 import torch
@@ -54,18 +54,22 @@ from ..model import load_model, load_tokenizer
 from .template import get_eval_template


+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

     @torch.inference_mode()
-    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
         logits = self.model(**batch_input).logits
         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -132,7 +136,7 @@ class Evaluator:
         pbar.close()
         self._save_results(category_corrects, results)

-    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
+    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
                 "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))

View File

@@ -47,6 +47,8 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

+IMAGE_PLACEHOLDER = "<image>"
+
 LAYERNORM_NAMES = {"norm", "ln"}

 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
@@ -93,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

+VIDEO_PLACEHOLDER = "<video>"
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"

 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -531,6 +535,10 @@ register_model_group(
"Gemma-1.1-7B-Chat": { "Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it", DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
}, },
"Gemma-2-2B": {
DownloadSource.DEFAULT: "google/gemma-2-2b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
},
"Gemma-2-9B": { "Gemma-2-9B": {
DownloadSource.DEFAULT: "google/gemma-2-9b", DownloadSource.DEFAULT: "google/gemma-2-9b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
@@ -539,6 +547,10 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b", DownloadSource.DEFAULT: "google/gemma-2-27b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
}, },
"Gemma-2-2B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
},
"Gemma-2-9B-Chat": { "Gemma-2-9B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
@@ -619,10 +631,22 @@ register_model_group(
 register_model_group(
     models={
+        "InternLM2.5-1.8B": {
+            DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
+            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
+        },
         "InternLM2.5-7B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
         },
+        "InternLM2.5-20B": {
+            DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
+            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
+        },
+        "InternLM2.5-1.8B-Chat": {
+            DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
+            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
+        },
         "InternLM2.5-7B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
@@ -631,6 +655,10 @@ register_model_group(
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
}, },
"InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
},
}, },
template="intern2", template="intern2",
) )
@@ -739,6 +767,37 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "LLaMA3.1-8B": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
+        },
+        "LLaMA3.1-70B": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
+        },
+        "LLaMA3.1-405B": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
+        },
+        "LLaMA3.1-8B-Chat": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
+        },
+        "LLaMA3.1-70B-Chat": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
+        },
+        "LLaMA3.1-405B-Chat": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
+        },
+    },
+    template="llama3",
+)
+
+
 register_model_group(
     models={
         "LLaVA1.5-7B-Chat": {
@@ -748,7 +807,7 @@ register_model_group(
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
}, },
}, },
template="vicuna", template="llava",
vision=True, vision=True,
) )
@@ -768,6 +827,17 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "MiniCPM3-4B-Chat": {
+            DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
+            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
+        },
+    },
+    template="cpm3",
+)
+
+
 register_model_group(
     models={
         "Mistral-7B-v0.1": {
@@ -791,6 +861,11 @@ register_model_group(
         },
         "Mistral-7B-v0.3-Chat": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
+            DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
+        },
+        "Mistral-Nemo-Chat": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
         },
     },
     template="mistral",
@@ -888,27 +963,28 @@ register_model_group(
 register_model_group(
     models={
-        "PaliGemma-3B-pt-224": {
+        "PaliGemma-3B-pt-224-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
         },
-        "PaliGemma-3B-pt-448": {
+        "PaliGemma-3B-pt-448-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
         },
-        "PaliGemma-3B-pt-896": {
+        "PaliGemma-3B-pt-896-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
         },
-        "PaliGemma-3B-mix-224": {
+        "PaliGemma-3B-mix-224-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
         },
-        "PaliGemma-3B-mix-448": {
+        "PaliGemma-3B-mix-448-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
         },
     },
+    template="paligemma",
     vision=True,
 )
@@ -1202,6 +1278,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B", DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B", DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
}, },
"Qwen2-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B",
},
"Qwen2-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B",
},
"Qwen2-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B",
},
"Qwen2-0.5B-Chat": { "Qwen2-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
@@ -1222,6 +1310,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
}, },
"Qwen2-Math-1.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B-Instruct",
},
"Qwen2-Math-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B-Instruct",
},
"Qwen2-Math-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B-Instruct",
},
"Qwen2-0.5B-int8-Chat": { "Qwen2-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
@@ -1263,6 +1363,38 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen2VL-2B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
+        },
+        "Qwen2VL-7B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
+        },
+        "Qwen2VL-2B-int8-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
+        },
+        "Qwen2VL-2B-int4-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
+        },
+        "Qwen2VL-7B-int8-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
+        },
+        "Qwen2VL-7B-int4-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
+        },
+    },
+    template="qwen2_vl",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "SOLAR-10.7B": {
@@ -1534,6 +1666,22 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
}, },
"Yi-Coder-1.5B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B",
},
"Yi-Coder-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B",
},
"Yi-Coder-1.5B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B-Chat",
},
"Yi-Coder-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B-Chat",
},
}, },
template="yi", template="yi",
) )
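For orientation, the registry entries above all flow through register_model_group(models=..., template=..., vision=...). The sketch below is a simplified stand-in that mimics that shape; the enum values, dictionaries, and bookkeeping are illustrative, not the project's actual definitions.

from enum import Enum
from typing import Dict, Optional


class DownloadSource(str, Enum):
    DEFAULT = "hf"     # Hugging Face Hub identifier
    MODELSCOPE = "ms"  # ModelScope mirror identifier


MODEL_PATHS: Dict[str, Dict[str, str]] = {}
MODEL_TEMPLATES: Dict[str, str] = {}
VISION_MODELS = set()


def register_model_group(models: Dict[str, Dict[DownloadSource, str]], template: Optional[str] = None, vision: bool = False) -> None:
    # Every model in a group shares one chat template; `vision` flags multimodal groups.
    for name, sources in models.items():
        MODEL_PATHS[name] = {source.value: path for source, path in sources.items()}
        if template is not None:
            MODEL_TEMPLATES[name] = template
        if vision:
            VISION_MODELS.add(name)


register_model_group(
    models={"Qwen2VL-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct"}},
    template="qwen2_vl",
    vision=True,
)
print(MODEL_TEMPLATES["Qwen2VL-7B-Chat"])  # qwen2_vl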

View File

@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.8.3"
+VERSION = "0.9.0"


 def print_env() -> None:

View File

@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,14 +18,21 @@
 import logging
 import os
 import sys
+import threading
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional

 from .constants import RUNNING_LOG


+_thread_lock = threading.RLock()
+_default_handler: Optional["logging.Handler"] = None
+_default_log_level: "logging._Level" = logging.INFO
+
+
 class LoggerHandler(logging.Handler):
     r"""
-    Logger handler used in Web UI.
+    Redirects the logging output to the logging file for LLaMA Board.
     """

     def __init__(self, output_dir: str) -> None:
@@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
         return super().close()


-def get_logger(name: str) -> logging.Logger:
+def _get_default_logging_level() -> "logging._Level":
     r"""
-    Gets a standard logger with a stream hander to stdout.
+    Returns the default logging level.
     """
-    formatter = logging.Formatter(
-        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
-    )
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(formatter)
-
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.INFO)
-    logger.addHandler(handler)
-
-    return logger
+    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str.upper() in logging._nameToLevel:
+            return logging._nameToLevel[env_level_str.upper()]
+        else:
+            raise ValueError("Unknown logging level: {}.".format(env_level_str))
+
+    return _default_log_level
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> "logging.Logger":
+    return logging.getLogger(_get_library_name())


-def reset_logging() -> None:
+def _configure_library_root_logger() -> None:
     r"""
-    Removes basic config of root logger. (unused in script)
+    Configures root logger using a stdout stream handler with an explicit format.
     """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
+    global _default_handler
+
+    with _thread_lock:
+        if _default_handler:
+            return
+
+        formatter = logging.Formatter(
+            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.setFormatter(formatter)
+
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None) -> "logging.Logger":
+    r"""
+    Returns a logger with the specified name. It it not supposed to be accessed externally.
+    """
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
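A short usage sketch of the refactored logging: verbosity is now taken from the LLAMAFACTORY_VERBOSITY environment variable (as in the diff above), and get_logger() hangs every module logger off one shared stream handler. The import path below assumes the repository's src layout:

import os

# Pick the library verbosity before the first get_logger() call configures the root handler.
os.environ["LLAMAFACTORY_VERBOSITY"] = "DEBUG"

from llamafactory.extras.logging import get_logger  # src/llamafactory layout assumed

logger = get_logger(__name__)
logger.info("goes to the shared stdout stream handler")
logger.debug("visible because LLAMAFACTORY_VERBOSITY=DEBUG")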

View File

@@ -37,7 +37,7 @@ from .logging import get_logger
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
-    _is_bf16_available = is_torch_bf16_gpu_available()
+    _is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())
 except Exception:
     _is_bf16_available = False
@@ -79,11 +79,11 @@ def check_dependencies() -> None:
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else: else:
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2") require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0") require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1") require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1") require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6") require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
@@ -137,7 +137,9 @@ def get_device_count() -> int:
r""" r"""
Gets the number of available GPU or NPU devices. Gets the number of available GPU or NPU devices.
""" """
if is_torch_npu_available(): if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count() return torch.npu.device_count()
elif is_torch_cuda_available(): elif is_torch_cuda_available():
return torch.cuda.device_count() return torch.cuda.device_count()
@@ -154,6 +156,18 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


+def get_peak_memory() -> Tuple[int, int]:
+    r"""
+    Gets the peak memory usage for the current device (in Bytes).
+    """
+    if is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_cuda_available():
+        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
+    else:
+        return 0, 0
+
+
 def has_tokenized_data(path: "os.PathLike") -> bool:
     r"""
     Checks if the path has a tokenized dataset.
@@ -181,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
 def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
+    r"""
+    Casts a torch tensor or a numpy array to a numpy array.
+    """
     if isinstance(inputs, torch.Tensor):
         inputs = inputs.cpu()
         if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
@@ -192,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
 def skip_check_imports() -> None:
+    r"""
+    Avoids flash attention import error in custom model files.
+    """
     if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
         transformers.dynamic_module_utils.check_imports = get_relative_imports
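The new get_peak_memory() helper above reports peak allocated and reserved bytes for the active device. A hedged sketch of how such a reading could be printed after a training step (the CUDA-only fallback and MB formatting are illustrative, not part of the diff):

import torch


def report_peak_memory() -> None:
    # Same idea as the helper above: CUDA (and NPU) expose max_memory_allocated/max_memory_reserved,
    # anything else falls back to zero.
    if torch.cuda.is_available():
        allocated, reserved = torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
    else:
        allocated, reserved = 0, 0

    print("peak allocated: {:.0f} MB, peak reserved: {:.0f} MB".format(allocated / 1024**2, reserved / 1024**2))


report_peak_memory()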

View File

@@ -38,6 +38,10 @@ def _get_package_version(name: str) -> "Version":
return version.parse("0.0.0") return version.parse("0.0.0")
def is_pyav_available():
return _is_package_available("av")
def is_fastapi_available(): def is_fastapi_available():
return _is_package_available("fastapi") return _is_package_available("fastapi")
@@ -70,19 +74,14 @@ def is_starlette_available():
return _is_package_available("sse_starlette") return _is_package_available("sse_starlette")
@lru_cache
def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0")
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")
def is_vllm_available(): def is_vllm_available():
return _is_package_available("vllm") return _is_package_available("vllm")
@lru_cache
def is_vllm_version_greater_than_0_5():
return _get_package_version("vllm") >= version.parse("0.5.0")
@lru_cache
def is_vllm_version_greater_than_0_5_1():
return _get_package_version("vllm") >= version.parse("0.5.1")

View File

@@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
     return fig


-def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
+def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     r"""
     Plots loss curves and saves the image.
     """

View File

@@ -73,6 +73,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."},
     )
+    preprocessing_batch_size: int = field(
+        default=1000,
+        metadata={"help": "The number of examples in one group in pre-processing."},
+    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
@@ -141,3 +145,6 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+        if self.mask_history and self.train_on_prompt:
+            raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

View File

@@ -326,6 +326,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
+    use_adam_mini: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the Adam-mini optimizer."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
     )

View File

@@ -15,23 +15,141 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
from typing_extensions import Self
if TYPE_CHECKING:
import torch
@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
@dataclass
class ModelArguments:
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
image_resolution: int = field(
default=512,
metadata={"help": "Keeps the height or width of image below this resolution."},
)
video_resolution: int = field(
default=128,
metadata={"help": "Keeps the height or width of video below this resolution."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},
)
video_maxlen: int = field(
default=64,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
@dataclass
class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
r""" r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer. Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
""" """
model_name_or_path: str = field( model_name_or_path: Optional[str] = field(
default=None,
metadata={ metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
}, },
@@ -77,26 +195,6 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
@@ -117,9 +215,13 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
@@ -145,22 +247,6 @@ class ModelArguments:
default="huggingface", default="huggingface",
metadata={"help": "Backend engine used at inference."}, metadata={"help": "Backend engine used at inference."},
) )
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
@@ -181,59 +267,38 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
)
block_diag_attn: bool = field(
default=False,
init=False,
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
)
def __post_init__(self):
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
self.block_diag_attn: bool = False
if self.model_name_or_path is None:
raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
@@ -247,9 +312,13 @@ class ModelArguments:
return asdict(self)
@classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map
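The refactor above splits ModelArguments into mixin dataclasses and adds init=False fields for values derived at runtime, so `copyfrom` has to drop those fields before re-instantiating. A simplified, hedged re-creation of that mechanism is shown below; it has far fewer fields than the real class, and `compute_dtype` is typed as a plain string to keep the sketch torch-free.

from dataclasses import asdict, dataclass, field, fields
from typing import Optional


@dataclass
class VllmArguments:
    vllm_maxlen: int = 2048


@dataclass
class ModelArguments(VllmArguments):
    model_name_or_path: Optional[str] = None
    compute_dtype: Optional[str] = field(default=None, init=False)  # derived, not user-facing

    @classmethod
    def copyfrom(cls, old_arg: "ModelArguments", **kwargs) -> "ModelArguments":
        arg_dict = asdict(old_arg)
        arg_dict.update(**kwargs)
        for attr in fields(cls):
            if not attr.init:  # init=False fields cannot be passed to the constructor
                arg_dict.pop(attr.name)

        new_arg = cls(**arg_dict)
        new_arg.compute_dtype = old_arg.compute_dtype  # carry the derived value over
        return new_arg


# e.g. build an inference-time copy with a longer vLLM context window:
# infer_args = ModelArguments.copyfrom(train_args, vllm_maxlen=4096)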

View File

@@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES
@@ -116,11 +116,14 @@ def _check_extra_dependencies(
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
@@ -128,6 +131,9 @@ def _check_extra_dependencies(
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
@@ -163,11 +169,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage != "pt" and data_args.template is None: if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.") raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate: if finetuning_args.stage != "sft":
raise ValueError("`predict_with_generate` cannot be set as True except SFT.") if training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage != "sft" and data_args.neat_packing: if data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.") raise ValueError("`neat_packing` cannot be set as True except SFT.")
if data_args.train_on_prompt or data_args.mask_history:
raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.") raise ValueError("Please enable `predict_with_generate` to save model predictions.")
@@ -175,21 +185,18 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if finetuning_args.stage == "ppo":
if not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
@@ -208,11 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate and data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
@@ -221,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.") raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16: if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available(): if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
raise ValueError("This device does not support `pure_bf16`.") raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
@@ -246,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
@@ -377,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp
from llamafactory.train.tuner import run_exp  # use absolute import
def launch():

View File

@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules
if TYPE_CHECKING:
@@ -37,7 +38,6 @@ logger = get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
@@ -46,13 +46,7 @@ def _setup_full_tuning(
return
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
@@ -63,7 +57,6 @@ def _setup_full_tuning(
def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
@@ -72,8 +65,8 @@ def _setup_freeze_tuning(
return
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
config = model.config
@@ -130,10 +123,7 @@ def _setup_freeze_tuning(
trainable_layers.append(module_name)
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
@@ -211,8 +201,7 @@ def _setup_lora_tuning(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
if (
finetuning_args.use_dora
@@ -303,9 +292,9 @@ def init_adapter(
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32

View File

@@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@@ -93,17 +95,24 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer)
if model_args.visual_inputs:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
raise ValueError(
"This multimodal LLM is not supported.\n"
"Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
"Download Yi-VL models from: https://huggingface.co/BUAADreamer"
)
else:
processor = None
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
if getattr(config, "model_type", None) == "qwen2_vl":
setattr(processor, "video_factor", 2)
else:
setattr(processor, "video_factor", 1)
except Exception:
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if "Processor" not in processor.__class__.__name__:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
@@ -145,12 +154,16 @@ def load_model(
if model_args.mixture_of_depths == "load": if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs) model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
model = AutoModelForCausalLM.from_config(config)
else: else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs) if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
else:
model = load_class.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert": if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args) model = convert_pretrained_model_to_mod(model, config, model_args)

View File

@@ -36,7 +36,7 @@ def configure_attn_implementation(
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:

View File

@@ -1,8 +1,10 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library.
# This code is inspired by the HuggingFace's Transformers and PEFT library,
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
# and the Unsloth library.
# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,9 +19,9 @@
# limitations under the License.
import inspect
from functools import partial
from functools import partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -36,8 +38,70 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
return UnslothGradientCheckpointing.apply
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
@@ -52,24 +116,18 @@ def _gradient_checkpointing_enable(
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if use_unsloth_gc:
gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
else:
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
def _fp32_forward_post_hook(
@@ -97,7 +155,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
gradient_checkpointing_enable = partial(
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

View File

@@ -0,0 +1,53 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.enable_liger_kernel:
return
model_type = getattr(config, "model_type", None)
if model_type == "gemma":
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
elif model_type == "gemma2":
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
else:
logger.warning("Current model does not support liger kernel.")
return
apply_liger_kernel()
logger.info("Liger kernel has been applied to the model.")

View File

@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -68,7 +70,11 @@ def llama_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -130,14 +136,15 @@ def llama_attention_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if is_transformers_version_greater_than_4_43():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_states.size(1),
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -28,17 +28,22 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
r""" r"""
Finds all available modules to apply lora or galore. Finds all available modules to apply lora or galore.
""" """
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"} forbidden_modules = {"lm_head"}
if model_type == "chatglm":
if model.config.model_type == "chatglm":
forbidden_modules.add("output_layer") forbidden_modules.add("output_layer")
elif model.config.model_type == "internlm2": elif model_type == "internlm2":
forbidden_modules.add("output") forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]: elif model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector") forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")
if freeze_vision_tower: if freeze_vision_tower:
forbidden_modules.add("vision_tower") if model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
module_names = set() module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
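The hunk stops at the `named_modules()` loop; the rest of the function collects the short names of the eligible linear layers. A simplified, hedged reconstruction of that collection step is shown below (not the repository's exact code, which filters a few more layer classes).

import torch


def collect_linear_module_names(model: "torch.nn.Module", forbidden_modules: set) -> list:
    module_names = set()
    for name, module in model.named_modules():
        if any(forbidden in name for forbidden in forbidden_modules):
            continue  # skip lm_head, vision towers, projectors, ...
        if isinstance(module, torch.nn.Linear):
            module_names.add(name.split(".")[-1])  # e.g. "q_proj", "v_proj"

    return list(module_names)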

View File

@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if not is_deepspeed_zero3_enabled():
return
if getattr(model.config, "model_type", None) == "dbrx":
model_type = getattr(model.config, "model_type", None)
if model_type == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
_set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
_set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
if model_type == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
if model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
if model_type == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
if model_type in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "jetmoe":
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
return
import transformers.models
if model_type == "cohere": if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon": elif model_type == "falcon":

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch
import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@@ -80,24 +80,98 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
self.act = ACT2FN[projector_hidden_act]
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "paligemma"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else:
return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook) mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None: def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models r"""
Patches VLMs before loading them.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava": # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.") logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "paligemma"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
elif model_type == "qwen2_vl":
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual")
if finetuning_args.train_mm_proj_only:
raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
return forbidden_modules
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
elif model_type == "qwen2_vl": # variable length
image_seqlen = -1
return image_seqlen
def patch_target_modules(
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
r"""
Freezes vision tower for VLM LoRA tuning.
"""
model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "paligemma"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
else:
if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
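As a quick illustration of the negative-lookahead patterns returned by `patch_target_modules` above, a frozen vision tower is excluded from LoRA target matching while ordinary language-model projections still match. The module names below are made up for the example; only the pattern shape comes from the diff.

    import re

    # Same shape as the pattern built above for qwen2_vl with a frozen vision tower.
    pattern = r"^(?!.*visual).*(?:{}).*".format("|".join(["q_proj", "v_proj"]))

    print(bool(re.match(pattern, "model.layers.0.self_attn.q_proj")))  # True: language-model weight
    print(bool(re.match(pattern, "visual.blocks.0.attn.q_proj")))      # False: rejected by the lookahead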


@@ -21,13 +21,13 @@ from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.liger_kernel import configure_liger_kernel
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -71,6 +71,7 @@ def patch_config(
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
+    configure_liger_kernel(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
@@ -89,9 +90,6 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-    if getattr(config, "model_type", None) == "chatglm":
-        require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@@ -133,11 +131,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)
-    if model_args.visual_inputs:
-        autocast_projector_dtype(model, model_args)
     if is_trainable:
         prepare_model_for_training(model, model_args)
+        autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
     if not model_args.use_unsloth:


@@ -32,9 +32,11 @@ from transformers.utils import (
     WEIGHTS_NAME,
     is_safetensors_available,
 )
+from typing_extensions import override
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
+from ..extras.misc import get_peak_memory
 if is_safetensors_available():
@@ -73,8 +75,8 @@ def fix_valuehead_checkpoint(
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
         state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-    decoder_state_dict = {}
-    v_head_state_dict = {}
+    os.remove(path_to_checkpoint)
+    decoder_state_dict, v_head_state_dict = {}, {}
     for name, param in state_dict.items():
         if name.startswith("v_head."):
             v_head_state_dict[name] = param
@@ -90,43 +92,52 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-    os.remove(path_to_checkpoint)
     logger.info("Value head model saved at: {}".format(output_dir))
 class FixValueHeadModelCallback(TrainerCallback):
+    r"""
+    A callback for fixing the checkpoint for valuehead models.
+    """
+    @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
             fix_valuehead_checkpoint(
-                model=kwargs.pop("model"),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
-                safe_serialization=args.save_safetensors,
+                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
             )
 class SaveProcessorCallback(TrainerCallback):
+    r"""
+    A callback for saving the processor.
+    """
     def __init__(self, processor: "ProcessorMixin") -> None:
-        r"""
-        Initializes a callback for saving the processor.
-        """
         self.processor = processor
+    @override
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         if args.should_save:
             getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
 class PissaConvertCallback(TrainerCallback):
     r"""
-    Initializes a callback for converting the PiSSA adapter to a normal one.
+    A callback for converting the PiSSA adapter to a normal one.
     """
+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
@@ -141,10 +152,8 @@ class PissaConvertCallback(TrainerCallback):
                 model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -162,29 +171,32 @@ class PissaConvertCallback(TrainerCallback):
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                 model.save_pretrained(
                     pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
-                )
+                )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
                 model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                 model.set_adapter("default")
-                model.delete_adapter("pissa_init")
+                if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
+                    model.delete_adapter("pissa_init")
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
 class LogCallback(TrainerCallback):
+    r"""
+    A callback for logging training and evaluation status.
+    """
     def __init__(self) -> None:
-        r"""
-        Initializes a callback for logging training and evaluation status.
-        """
-        """ Progress """
+        # Progress
         self.start_time = 0
         self.cur_steps = 0
         self.max_steps = 0
         self.elapsed_time = ""
         self.remaining_time = ""
         self.thread_pool: Optional["ThreadPoolExecutor"] = None
-        """ Status """
+        # Status
         self.aborted = False
         self.do_train = False
-        """ Web UI """
+        # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
@@ -224,10 +236,8 @@ class LogCallback(TrainerCallback):
             self.thread_pool.shutdown(wait=True)
             self.thread_pool = None
+    @override
     def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of the initialization of the `Trainer`.
-        """
         if (
             args.should_save
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
@@ -236,55 +246,41 @@ class LogCallback(TrainerCallback):
             logger.warning("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             self.do_train = True
             self._reset(max_steps=state.max_steps)
             self._create_thread_pool(output_dir=args.output_dir)
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         self._close_thread_pool()
+    @override
     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of an substep during gradient accumulation.
-        """
         if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True
+    @override
     def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of a training step.
-        """
         if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True
+    @override
     def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after an evaluation phase.
-        """
         if not self.do_train:
             self._close_thread_pool()
+    @override
     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a successful prediction.
-        """
         if not self.do_train:
             self._close_thread_pool()
+    @override
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after logging the last logs.
-        """
         if not args.should_save:
             return
@@ -302,26 +298,31 @@ class LogCallback(TrainerCallback):
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
-            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
-            total_tokens=state.num_input_tokens_seen,
         )
+        if state.num_input_tokens_seen:
+            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
+            logs["total_tokens"] = state.num_input_tokens_seen
+        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
+            vram_allocated, vram_reserved = get_peak_memory()
+            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
+            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
+                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
                 )
             )
         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
+    @override
     def on_prediction_step(
         self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
     ):
-        r"""
-        Event called after a prediction step.
-        """
         if self.do_train:
             return
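For reference, the optional VRAM fields added to `on_log` above are gated by the `RECORD_VRAM` environment variable; `get_peak_memory` is imported from `extras.misc` and its body is not shown in this diff. A rough sketch of the idea follows, with the helper assumed to wrap `torch.cuda` peak-memory statistics.

    import os

    import torch


    def get_peak_memory_sketch():
        # Assumed implementation: peak allocated / reserved device memory in bytes.
        if torch.cuda.is_available():
            return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
        return 0, 0


    logs = {}
    if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
        vram_allocated, vram_reserved = get_peak_memory_sketch()
        logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)  # GiB
        logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
    print(logs)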


@@ -26,10 +26,11 @@ import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 if TYPE_CHECKING:
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
         return losses, chosen_rewards, rejected_rewards
+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -176,7 +180,6 @@ class CustomDPOTrainer(DPOTrainer):
         batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -187,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_length, _ = valid_length.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -208,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
         return reference_chosen_logps, reference_rejected_logps
+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",


@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,13 +41,15 @@ def run_dpo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = PairwiseDataCollatorWithPadding(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        **tokenizer_module,
     )
     # Create reference model
@@ -60,7 +62,7 @@ def run_dpo(
         ref_model = None
     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
     # Initialize our Trainer
     trainer = CustomDPOTrainer(


@@ -25,10 +25,11 @@ import torch
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 if TYPE_CHECKING:
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
+    @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
         return Trainer._get_train_sampler(self)
+    @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -127,17 +132,20 @@ class CustomKTOTrainer(KTOTrainer):
             "input_ids": batch["{}input_ids".format(prefix)],
             "attention_mask": batch["{}attention_mask".format(prefix)],
         }
-        if "pixel_values" in batch:
-            model_inputs["pixel_values"] = batch["pixel_values"]
         if "{}token_type_ids".format(prefix) in batch:
             model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+        if "pixel_values" in batch:
+            model_inputs["pixel_values"] = batch["pixel_values"]
+        if "image_grid_thw" in batch:
+            model_inputs["image_grid_thw"] = batch["image_grid_thw"]
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length
+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -153,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
         return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -173,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",


@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
-from ...data import KTODataCollatorWithPadding, get_dataset
+from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,13 +41,15 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = KTODataCollatorWithPadding(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        **tokenizer_module,
     )
     # Create reference model
@@ -57,7 +59,7 @@ def run_kto(
         ref_model = create_ref_model(model_args, finetuning_args)
     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
     # Initialize our Trainer
     trainer = CustomKTOTrainer(


@@ -35,11 +35,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
+from typing_extensions import override
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -133,6 +134,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             ref_model=ref_model,
             tokenizer=tokenizer,
             dataset=train_dataset,
+            optimizer=optimizer,
             data_collator=data_collator,
             lr_scheduler=scheduler,
         )
@@ -297,13 +299,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.callback_handler.on_train_end(self.args, self.state, self.control)
+    @override
     def create_optimizer(
         self,
         model: "AutoModelForCausalLMWithValueHead",
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
     ) -> "torch.optim.Optimizer":
-        optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+        optimizer = create_custom_optimizer(model, training_args, finetuning_args)
         if optimizer is None:
             decay_params, nodecay_params = [], []
             decay_param_names = self.get_decay_parameter_names(model)
@@ -323,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         return optimizer
+    @override
     def create_scheduler(
         self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -409,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type
+    @override
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
@@ -477,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             torch.cat(all_masks)[:, :-1],
         )
+    @override
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.


@@ -17,9 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
-from transformers import DataCollatorWithPadding
-from ...data import get_dataset
+from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -43,11 +41,12 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)


@@ -16,10 +16,11 @@ from types import MethodType
 from typing import TYPE_CHECKING, Optional
 from transformers import Trainer
+from typing_extensions import override
 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":

Some files were not shown because too many files have changed in this diff.