159 Commits

Author SHA1 Message Date
hiyouga
dfff411e1a release v0.9.0 (real)
Former-commit-id: 8ff781c8ae5654680f738f69a6db9d7b95d76baf
2024-09-09 01:00:25 +08:00
hiyouga
e20baa4218 fix constants
Former-commit-id: fce6671d2764d7a2b77c44401fc5582c7cbb77aa
2024-09-08 23:52:30 +08:00
hiyouga
d1ab9b501a release v0.9.0
Former-commit-id: 594c450f648ad326ef39c0f4d70d67cda5f36159
2024-09-08 23:43:35 +08:00
hiyouga
3cbc9109ea tiny fix
Former-commit-id: 76177039c8f9ef5a63724a339dae6195d89fa215
2024-09-08 23:18:08 +08:00
hiyouga
3259397f89 update scripts
Former-commit-id: 51d087cbc14bf3c7dfa06b8b66052cd80a6081be
2024-09-08 14:17:41 +08:00
hiyouga
eb5af3d90b support vllm 0.6.0
Former-commit-id: e39470ec51a9c74ad901871eb816df10e851f351
2024-09-08 02:26:20 +08:00
hiyouga
b6810b209a fix test case
Former-commit-id: b075b2971c6acb2c6039b36420a296f1f4e1b91b
2024-09-08 01:50:51 +08:00
hiyouga
158e0e1f63 add test case
Former-commit-id: c452d65e1551074dddd1d87517c0d44dc014c6aa
2024-09-08 01:40:49 +08:00
hiyouga
294a103ead support activation offloading via unsloth gc
Former-commit-id: d3d0dd0feba3ca6f0ae970d5856bec989d26ef67
2024-09-08 01:22:19 +08:00
hiyouga
7f71276ad8 add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
2024-09-08 00:56:56 +08:00
hoshi-hiyouga
93d4570a59 Merge pull request #5388 from yzoaim/cal_mfu_update
update cal_mfu.py

Former-commit-id: fe5eac2cb6a4646b653232d7c68d535105b60f3a
2024-09-08 00:49:28 +08:00
hoshi-hiyouga
527ba2eb2e fix
Former-commit-id: 53a74cbc3afec58b36c2265e080061bcdf702f98
2024-09-08 00:41:45 +08:00
hoshi-hiyouga
3021b31cf3 Update cal_mfu.py
Former-commit-id: 0c391b2e59943b0ca9dd4e8561398e7c856a4b29
2024-09-08 00:39:48 +08:00
-.-
9f2427907e update cal_mfu.py
Former-commit-id: 1cdbb4c774d463969c6be14fb08d92c7a0bdb565
2024-09-07 23:21:35 +08:00
hoshi-hiyouga
570ce100c1 fix #5384
Former-commit-id: 2e86c54f381f7403c30ba78d2acf5003aab6e049
2024-09-07 01:21:14 +08:00
hiyouga
27547355e6 tiny fix
Former-commit-id: c0e9c0484dae6db93cef5048bad827ff22b1986a
2024-09-05 23:41:16 +08:00
hiyouga
c5ef52a67a fix ci
Former-commit-id: b5ffca5a190f3aed8ba8c49bd8cf3239fb787bf5
2024-09-05 22:39:47 +08:00
hiyouga
b48b47d519 fix ci
Former-commit-id: cf0758b03e9b8b4931ba790a9726b8256ee4286c
2024-09-05 22:27:48 +08:00
hiyouga
9bdba2f6a8 add e2e tests
Former-commit-id: 0156a37450604641c4f5f9756ad84324698fc88c
2024-09-05 21:52:28 +08:00
hoshi-hiyouga
d6ce902d80 Merge pull request #5372 from LDLINGLINGLING/main
增加了对minicpm3.0的适配'

Former-commit-id: 2e3c221d9c87bd59f48648be8878b7b50347280f
2024-09-05 21:35:42 +08:00
liudan
ce6dcf3600 根据代码规范修改了代码
Former-commit-id: fe5351980b42e0e38175b0da2401a61b3807fa7c
2024-09-05 20:17:55 +08:00
hoshi-hiyouga
e7f92d16d8 fix #5366
Former-commit-id: b0a4964846dd5be7aa2c54d43f28ba62985587f1
2024-09-05 18:08:09 +08:00
hiyouga
abd26f5f67 update data readme
Former-commit-id: 0af5f054b7b8da8b39eb44b1dfa76050f0c45667
2024-09-05 04:44:49 +08:00
hiyouga
4d35ace75e update data readme
Former-commit-id: 81adb153b7d0b30e6cd50c9bf4ca1ccf17458611
2024-09-05 04:25:27 +08:00
hiyouga
72222d1598 support Yi-Coder models
Former-commit-id: ea3f1659e70541c4fa8b7079a0a8c94fce9a41c8
2024-09-05 03:12:24 +08:00
hiyouga
26d914b8fc fix ci
Former-commit-id: 280c0f3f2cea4dfced797cc0e15f72b8b3a93542
2024-09-05 03:02:59 +08:00
hiyouga
7b01c0676c fix ci
Former-commit-id: 7899b44b19c3d0a70706d987bb7d2e0e3536014b
2024-09-05 02:49:22 +08:00
hiyouga
571a9b8669 update ci
Former-commit-id: e24bf7345442701ca874d439f0ca3da49fa59a84
2024-09-05 02:26:10 +08:00
hoshi-hiyouga
ed35eb1e9e Merge pull request #5365 from hiyouga/video_finetuning
Support Qwen2-VL Fine-Tuning on Video Datasets

Former-commit-id: 178cc3fbc48bf2c68723b487681db04e660b12fa
2024-09-05 02:24:58 +08:00
hiyouga
d291e0d60d tiny fix
Former-commit-id: 9da6e084e1e5daf7403e7fabeaaec686167fb11f
2024-09-05 02:16:49 +08:00
hiyouga
1874d579c5 video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
2024-09-05 02:04:17 +08:00
liudan
c692339020 增加了对minicpm3.0的适配'
Former-commit-id: 4ad3a761af2452ef3f6c61190b7e47c9ea5227b9
2024-09-04 23:10:05 +08:00
hiyouga
2c1eef34cb fix test
Former-commit-id: 553a83aff9f9da35c9a0eca81f7d2b0bf2adf6ff
2024-09-04 22:38:26 +08:00
hiyouga
af178cbcd1 update get template
Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
2024-09-04 22:36:20 +08:00
hoshi-hiyouga
5d85be31ca Merge pull request #5323 from naem1023/feat/add-dataset-map-batch-size-argument
Add batch size of map function in the preprocessed dataset

Former-commit-id: c3428c5807500d087cdee4386798e10e39c9cf30
2024-09-04 22:09:36 +08:00
hoshi-hiyouga
372b71c847 fix #5228
Former-commit-id: 0d332ca8d0987c0331361934ab110fafa6402a7e
2024-09-04 19:10:30 +08:00
hiyouga
41a9c415e1 fix #5252
Former-commit-id: 73f30b4dfffb260e24f9e2332617b8ca2c249ed5
2024-09-04 03:17:54 +08:00
hiyouga
915e32a5f8 add vl_feedback dataset
Former-commit-id: 6ff34ad2db383b5fbd51008bcc5eec880658811e
2024-09-04 03:13:03 +08:00
hiyouga
f4dd429cbf fix #5344
Former-commit-id: 9d445c0b5be5ccc0e6d1979e76a869ddf92d9534
2024-09-04 03:06:06 +08:00
hoshi-hiyouga
7435cde2ef Merge pull request #5346 from hiyouga/lazy_image
[exp] Lazyload for multimodal inputs

Former-commit-id: 4bbd721361a8c5888b28f5fcdcbb2a4ad2305445
2024-09-04 03:00:53 +08:00
hiyouga
7056087e92 lazy image load
Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
2024-09-04 02:27:08 +08:00
hiyouga
fed7ae5661 fix #5334
Former-commit-id: a5ea0f83f00c81d128a1f50ce244866ce38ee15f
2024-09-03 19:09:42 +08:00
hiyouga
5019c6148b fix #5338
Former-commit-id: a66ddfea218feefde50fa097d20b4bcbe89ab791
2024-09-03 17:45:17 +08:00
hiyouga
2e1396cd6b lint
Former-commit-id: d821d933e6cb982d648a69f85f6ad01d0560ed70
2024-09-03 00:46:25 +08:00
hiyouga
b5e9df5df8 fix #5324
Former-commit-id: f7aa06c9c0b18c28419ea5792410915d3f322cbf
2024-09-02 23:56:21 +08:00
naem1023
3622856994 feat: add batch size of map function in the preprocessed dataset
Former-commit-id: 94b6cf06c2f84d0619b1a2dccaf8abb51de9951c
2024-09-02 13:52:47 +09:00
hoshi-hiyouga
7367c6ec21 fix trainer predict
Former-commit-id: 2790790cd26c6743105555a60523b89f367ebce3
2024-09-02 10:15:29 +08:00
hoshi-hiyouga
6579ec8c4c remove .cpu()
Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac
2024-09-02 10:10:53 +08:00
hiyouga
a7fbae47d5 fix mm inference
Former-commit-id: fa782c15a07ed40f8a6381acdf2da395377efd04
2024-09-02 01:47:40 +08:00
hiyouga
f203a9d78e tiny fix
Former-commit-id: 8b4f408da110d74285bae20bbd969013a979964b
2024-09-02 01:33:22 +08:00
hiyouga
bae73e676c add image num check
Former-commit-id: 15201113bf16b748c0a758c7a5b363da8272e0e6
2024-09-02 01:31:36 +08:00
hiyouga
806e1061d4 add pokemon dataset
Former-commit-id: 06680158a0f0a1e3c542e77af92ac877fbe357c5
2024-09-02 01:02:25 +08:00
hiyouga
f920091667 update readme
Former-commit-id: 25a05d9f96718e06ce83f5bb1f41d2c001790295
2024-09-01 23:32:39 +08:00
hiyouga
801979f779 update wechat
Former-commit-id: 7f88dfe080db10ff12d1fb80b43099a356c899ea
2024-09-01 23:30:57 +08:00
hoshi-hiyouga
df2d32e7aa Merge pull request #5317 from ByronHsu/patch-1
Add liger kernel link

Former-commit-id: a319b3cf9119fd49cbcfb17b963e111a2f86bb51
2024-09-01 23:30:12 +08:00
hiyouga
60cf12727b add rlhf-v dataset
Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
2024-09-01 22:57:41 +08:00
hiyouga
7621526d22 tiny fix
Former-commit-id: 8ccaae3871d8d1fe3ea4633d427aecb2ab3addec
2024-09-01 21:15:44 +08:00
hiyouga
559b84dceb fix bug
Former-commit-id: 6e19e56000dd18d5faf84ceabce8d7708ff21e4d
2024-09-01 21:07:49 +08:00
hiyouga
7e4c5d4bb3 fix mixed mm inputs and rlhf-v
Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
2024-09-01 20:52:47 +08:00
Byron Hsu
2a4ed6610e Add liger kernel link
Former-commit-id: 4f313044cf8efd9c6ebcbd4741f6f38d56804b7f
2024-08-30 17:16:16 -07:00
hiyouga
1d8e9c7897 fix ci (temp)
Former-commit-id: 9ebaafd2e5c16ecef0243e4df77344ed7c823e57
2024-08-31 02:03:56 +08:00
hiyouga
43654028eb add test mm plugin
Former-commit-id: ddea5cca5a3174de1dcc7fdee8ec69e77700b6bf
2024-08-31 01:53:38 +08:00
hiyouga
2f6fc27c8b remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
2024-08-31 00:24:51 +08:00
hiyouga
d789b667d7 optimize predict vram
Former-commit-id: a577e44eee351b3ed8011a33ae01cd713354ff97
2024-08-30 23:08:45 +08:00
hiyouga
66a1abac6a add examples
Former-commit-id: 169c68921b1b8ac279834b060d9e7d38a56fe1aa
2024-08-30 21:43:19 +08:00
hiyouga
665db18661 tiny fix
Former-commit-id: 830511a6d0216da99520aee8b3a753d347a71fa9
2024-08-30 03:21:50 +08:00
hiyouga
30d97ca879 fix #5307
Former-commit-id: 63c19ddfe483a16c1c9afc2f1441e8070bb0f7e4
2024-08-30 02:45:40 +08:00
hiyouga
c62a6ca59d refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
2024-08-30 02:14:31 +08:00
hoshi-hiyouga
77c2c7076b Merge pull request #5290 from simonJJJ/qwen2_vl
support qwen2-vl

Former-commit-id: 7156f832af8505b26371559d340c0e69eb962bbc
2024-08-30 02:10:36 +08:00
hoshi-hiyouga
7466fd4387 fix bug
Former-commit-id: 365e6df71509569f59c40743c115f1a4b945ef0f
2024-08-30 02:05:26 +08:00
hiyouga
c1369a1ec9 update liger kernel
Former-commit-id: d6bf6ca2161c99dd5d644e31d2b1df451017b68c
2024-08-29 20:46:08 +08:00
hiyouga
d677fe053d fix #5292
Former-commit-id: dd81ce8ce5fdf450027c5f9634abb6ac2cd52128
2024-08-29 20:37:47 +08:00
hiyouga
7c6785d3df fix #5295
Former-commit-id: c76873b0eb8225f6e6bfc7223c6012387dceb8ed
2024-08-29 20:30:18 +08:00
hiyouga
77341ee3c4 fix #5305
Former-commit-id: a710ebaf97c258c802f24e508d83f1f3f10edc6d
2024-08-29 20:16:01 +08:00
simonJJJ
5b4b60cfb5 update
Former-commit-id: a968a416d5e513320c97109229ca1e6ddc003cb1
2024-08-28 20:22:46 +08:00
simonJJJ
0f3d54d8a0 initial-commit
Former-commit-id: b6a39847a10b417b09db4b5512dd835e9e4ce928
2024-08-28 16:51:35 +08:00
hiyouga
7272792f65 update wechat
Former-commit-id: ef91752cc6f53088eaf7fc2f64f7148821d82ec2
2024-08-27 12:55:23 +08:00
hiyouga
4cc8e16595 add extra requires
Former-commit-id: c47511773ae9886aae4e5ea1841866d2125abc34
2024-08-27 12:52:12 +08:00
hiyouga
ca5a759f94 tiny fix
Former-commit-id: d2cede7023bbe28525ef8b4ad27247445d8c22e5
2024-08-27 12:49:32 +08:00
hoshi-hiyouga
be51e56a2e Merge pull request #5237 from marko1616/patch-1
Fix mllm api

Former-commit-id: 017703c7ab7f3dc566792619537c3202ca4f4bb7
2024-08-27 12:24:43 +08:00
marko1616
3a9171e275 ruff pass.
Former-commit-id: c2f817772f8e7d947dca04f546befc70001abe64
2024-08-27 11:30:16 +08:00
marko1616
bd0f3b4050 Update chat.py
Former-commit-id: 4e5893a5c4a47ff3cb989bbef0841effc713fc08
2024-08-27 11:27:56 +08:00
hiyouga
206a8364d4 support liger kernel
Former-commit-id: 0f4e54abf6c5feb2329855a4047597ad5147720a
2024-08-27 11:20:14 +08:00
marko1616
097d031066 Force re check.
Former-commit-id: 5f04452f7d65e535d0af08944f7b9e29e85f51d7
2024-08-23 14:43:18 +08:00
marko1616
2674b42b59 Update chat.py
Former-commit-id: 206a16c17d253956afb96daea6f24478e17334fc
2024-08-22 12:24:34 +08:00
marko1616
edf2e51bbc Update chat.py
Former-commit-id: edf6dc1995daa6c3635c3fda1052b340693a04f5
2024-08-22 12:14:34 +08:00
MengqingCao
47877acc2a update npu base image
Former-commit-id: 20819f7707cfff6b951484e91fc7ecda2bf68528
2024-08-21 09:12:38 +00:00
hiyouga
d111a324bc tiny fix
Former-commit-id: 23961bdf6fdbcde64e7b943f699fdeb4ac024043
2024-08-20 00:10:52 +08:00
hoshi-hiyouga
388f0a6e05 Merge pull request #5156 from YeQiuO/main
fix Llama-template's system prompt bug

Former-commit-id: 0b57175d3bd029675dae2f55995b7eeb4e9adc7a
2024-08-20 00:09:03 +08:00
hoshi-hiyouga
8c13c02c55 Update template.py
Former-commit-id: f5a075cb1c90f05bb0de26c6aea718f556c54623
2024-08-20 00:03:33 +08:00
hoshi-hiyouga
a101fde917 Merge pull request #5163 from liu-zichen/fix_ppo_optim
fix lr not change

Former-commit-id: f3c03ec6a89bf57f290820fa31eda24291355e4e
2024-08-19 23:56:24 +08:00
hoshi-hiyouga
1f4373b6e5 Merge pull request #5185 from chenhuiyu/feature/add-sailorllm-template
Add SailorLLM template

Former-commit-id: 28387d6b2f9e3bcc6321345c46b525c8180ebf7e
2024-08-19 23:51:49 +08:00
hoshi-hiyouga
525747b472 Merge pull request #5188 from Zxilly/main
fix: report correct device count for intel xpu
Former-commit-id: cd3c536cb3936061d905256850b0e57df4498010
2024-08-19 23:51:39 +08:00
hoshi-hiyouga
472f12c985 Merge pull request #5193 from Ricardo-L-C/main
_is_bf16_available judgment supports npu

Former-commit-id: 18b9ac49c45af773a2ea563f5e1852dc4b775db8
2024-08-19 23:40:59 +08:00
hoshi-hiyouga
b681f24f43 Update template.py
Former-commit-id: c6822a217e1c296f4aedd9a2c7610acd1dbd443e
2024-08-19 23:40:16 +08:00
hiyouga
fd02b089b6 update readme
Former-commit-id: 756e438866876fa54495cf557dd1e299b17a42fb
2024-08-19 23:32:04 +08:00
Ricardo
57d4c4a4f8 _is_bf16_available judgment supports npu
Former-commit-id: 50a1e892a1005b4cdd82dca1005f71db08ed89a2
2024-08-16 02:58:22 +00:00
Zxilly
3595d26846 fix: report correct device count for intel xpu
Former-commit-id: 0618f660b6511599365bd9be64499dbab41a79ba
2024-08-15 08:30:43 +00:00
Huiyu Chen
22a79c169d Add SailorLLM template
Former-commit-id: a594abe0321a718394a97b5a48ded16e2012c1f0
2024-08-15 15:10:14 +08:00
liu-zichen
75dfe259cf fix lr not change
Former-commit-id: 387dd2d51b5d8cd666459040fdd16525b34720d9
2024-08-13 16:33:34 +08:00
codingma
2e257d6af0 add tutorial and doc links
Former-commit-id: 4f6072562a34e0ec97471210ff54244cf0d0f3df
2024-08-13 16:13:10 +08:00
“Wzw”
e734222373 fix Llama-template's system prompt bug
Former-commit-id: 2e3eddcd0918b0c968ded0df7c82e3dcff870381
2024-08-12 19:22:12 +08:00
hiyouga
6a351b9912 update readme
Former-commit-id: 4fecc5ee56873a7ab4941e46a5168cfe2ecb4bb6
2024-08-10 10:17:35 +08:00
hiyouga
cfc04aa162 update readme
Former-commit-id: fa7bc9f1c7347153f9092ffbbb8e88c6b2f59632
2024-08-09 20:46:02 +08:00
hiyouga
943c795318 add magpie ultra dataset
Former-commit-id: 3317b24329b87e30f13a78936ac5554f211abf7a
2024-08-09 20:28:55 +08:00
hiyouga
7fb61bad04 add qwen2 math models
Former-commit-id: 72ff43a1772c9de5ff914d5e1c8bdc8dea9ae0c8
2024-08-09 20:20:35 +08:00
hiyouga
47efcdb1dd update examples
Former-commit-id: d5c57c8b7f64afe8061045ec9689abbac45c1175
2024-08-09 20:13:46 +08:00
hiyouga
59cbce1a46 add adam_mini to readme
Former-commit-id: d610c6bcf8a8ba6f4236f5d11f79571b83f4fb11
2024-08-09 20:02:03 +08:00
hoshi-hiyouga
7e755e9cac Merge pull request #5095 from relic-yuexi/feat-optimizer
Feat optimizer

Former-commit-id: f08390d252d42a812b71a08daba7339cc40889b7
2024-08-09 19:51:33 +08:00
hiyouga
9d1e2c3c1f update scripts
Former-commit-id: dabf5a1dc661a6581474c6a5ec115322d168ed5f
2024-08-09 19:16:23 +08:00
hiyouga
5af32ce705 follow #5115
Former-commit-id: 7d917e03e2df570139bae18227d9c7303a12de2a
2024-08-09 18:03:00 +08:00
hoshi-hiyouga
4e8861e653 Merge pull request #5115 from YeQiuO/main
fix: `Train on the last turn only` truncate bug
Former-commit-id: 2c6dae45f7a7b72c961489ac407b1b444ab7752e
2024-08-09 17:58:27 +08:00
hoshi-hiyouga
d4d7ffb17c Merge pull request #5072 from relic-yuexi/main
fix the deepseekcoder template to avoid repeat problem

Former-commit-id: 2ae7d5c91725eab9f994015d8d3577894c7978b6
2024-08-09 16:35:21 +08:00
hoshi-hiyouga
46f834ec75 Update template.py
Former-commit-id: ae2a5221c109ae3474d219c37433be767abbee91
2024-08-09 16:27:42 +08:00
“Wzw”
6ec64a7e56 mask_history args verify valid
Former-commit-id: 2f8388b4f4195d934400ad9267d72e10ca4105a3
2024-08-08 10:12:01 +08:00
“Wzw”
d71446e387 fix mask_history tiny bug
Former-commit-id: cac07aac6196be026f723b2397a343d4fb675973
2024-08-08 10:09:33 +08:00
codingma
eada49e56b fix eval_dataset in example
Former-commit-id: e1ffc54f7e58419cc8da958a4d3c2697e18d5583
2024-08-07 18:24:19 +08:00
moontidef
8f42d7df56 feat: add support for adammini
Former-commit-id: a2d5fafb705ff44db1711e972490f0abebc2012b
2024-08-07 10:08:22 +08:00
moontidef
33a90b9026 fix: rename optimzer to optimizer
Former-commit-id: 186dc1fde822e6a603ac273538741ea3853f243e
2024-08-07 10:05:01 +08:00
moontidef
710902b0d0 Merge branch 'hiyouga:main' into main
Former-commit-id: d1b23283e0e4286f126d38d7bdc55802f74c8922
2024-08-06 00:18:45 +08:00
moontidef
7b4f5d3b21 fix: fix the deepseekcoder template to avoid repeat problem
Former-commit-id: 56294831115f095135f72490a8a435434b2f0a11
2024-08-05 23:55:45 +08:00
hiyouga
13093963b1 fix #5048
Former-commit-id: 71a6861667ae68c1fd6a69acf68e1359b858cf1b
2024-08-05 23:48:19 +08:00
hoshi-hiyouga
2e477e7458 Merge pull request #5037 from codemayq/feature-gemma-2-2b
support gemma-2-2b

Former-commit-id: 6af51fadff92cd3e665c556ac073a1876f792ada
2024-08-05 23:27:37 +08:00
codingma
4b6252151e support gemma-2-2b
Former-commit-id: 7037192cf6049fd7d675aed4a6237ed929c6b170
2024-08-01 13:45:48 +08:00
hoshi-hiyouga
f3765d1996 Merge pull request #5010 from Eruly/main
Add Korean web UI (llamafactory-cli webui)

Former-commit-id: 2050806aa826028df45c0c746b4314afe178dcd3
2024-07-30 01:55:54 +08:00
hoshi-hiyouga
1f5cdd66b7 Merge pull request #4996 from LDLINGLINGLING/main
增加了MiniCPM在页面首页的支持列表,MiniCPM官方github也放了LLama_factory的友情链接

Former-commit-id: a86a776fb0f75697b0fee7694a5a0d6bd04fee0a
2024-07-30 01:55:30 +08:00
hoshi-hiyouga
5b0ddbb835 Update README_zh.md
Former-commit-id: 922906faf2d432def7cfdac82f90472fa1bb24a9
2024-07-30 01:55:13 +08:00
hoshi-hiyouga
4f92b56f06 Update README.md
Former-commit-id: 6bc7f71940be0a8f1614f9036b9c539ce46d34e1
2024-07-30 01:53:19 +08:00
hoshi-hiyouga
a1f6ff92be Update README.md
Former-commit-id: 54eecdec0da06677ea55847c74642d0fc12d8908
2024-07-30 01:52:35 +08:00
hoshi-hiyouga
ef98e91618 Merge pull request #4995 from codemayq/fix-pissa
fix pissa callback

Former-commit-id: 052c0f6bd9e872ea325b5a6aef98c4c070733384
2024-07-30 01:47:25 +08:00
eruly
9fdf800750 Add Korean web UI (llamafactory-cli webui)
Former-commit-id: 357a035f2aeb9548368c230c5a17dcdfa4844b17
2024-07-29 13:47:13 +00:00
liudan
32c698e4c2 增加了MiniCPM在页面首页的支持列表,MiniCPM官方github也放了LLama_factory的友情链接
Former-commit-id: f482a6e2fd30aff5113e53f3f07b4649982bcc2e
2024-07-29 10:58:28 +08:00
codingma
75e80fa820 fix pissa save
Former-commit-id: 25a1dad7c8df79c15efecb8c6f871a13a327f57a
2024-07-29 10:44:34 +08:00
hiyouga
f8329bc632 tiny fix
Former-commit-id: 183d8bd500a8e9513a077161ba8e8d61bea9200f
2024-07-26 11:51:00 +08:00
hoshi-hiyouga
9f74d36ba4 Merge pull request #4892 from piamo/main
update deepseek template

Former-commit-id: 3233efc8404972098665286d9dec7312dd6ecfab
2024-07-26 11:49:34 +08:00
hoshi-hiyouga
fc2435f135 Merge pull request #4950 from liuwwang/main and fi
fix: Repair the issue where quantization failed after merging the adapter.
Former-commit-id: 93a68ea1f4372973f745a2c250250ecaac515e27
2024-07-26 11:48:56 +08:00
hoshi-hiyouga
0636519ba3 Merge pull request #4970 from HardAndHeavy/add-rocm
Add ROCm support

Former-commit-id: c0f21d869bce6e59825d57c66bce3fe54f50065f
2024-07-26 11:41:23 +08:00
hoshi-hiyouga
573bf03a6f Update README_zh.md
Former-commit-id: 86a27a97ff67b0d4bcd671c62759cd049542dc1b
2024-07-26 11:30:57 +08:00
hoshi-hiyouga
9e529be4e7 Update README.md
Former-commit-id: 1c167bb2ea3a47bdeeccc044a653662132c61698
2024-07-26 11:29:28 +08:00
hoshi-hiyouga
7af4ffa6cc Update README.md
Former-commit-id: d6e7a69c274c3756587e18a039637dd37fa152b2
2024-07-26 11:29:09 +08:00
HardAndHeavy
5b67ccd1c6 Add ROCm support
Former-commit-id: cf9df10a24936efd420b0fdac541fd6c0808a327
2024-07-25 21:29:28 +03:00
khazic
5166dbbcd3 Added the reference address for TRL PPO details.
Former-commit-id: 509c55608643eae3a6456683d425a7c636cfc3e9
2024-07-25 09:03:21 +08:00
hiyouga
21adb09730 fix #4959
Former-commit-id: 96e8a1d47874708c6157865c78be4cd6c533e01b
2024-07-24 23:44:00 +08:00
hiyouga
28b5f656db update webui
Former-commit-id: 463edec1b1c1345afc791e225deb33f118f3582e
2024-07-24 21:11:51 +08:00
hoshi-hiyouga
68ee2d512f Update README_zh.md
Former-commit-id: 1443e876697e18108573387e501a7453ba9fc06c
2024-07-24 21:08:42 +08:00
hoshi-hiyouga
a5f7e0efc6 Update README.md
Former-commit-id: 07d86e38cfd857d1dfa898541f3e5bd9c6f11581
2024-07-24 21:07:14 +08:00
hiyouga
211038584a tiny fix
Former-commit-id: 28cac0e325bfd7a6c0c344ad2d46511613190cd7
2024-07-24 18:33:39 +08:00
hiyouga
ff5ba97970 fix #4928
Former-commit-id: 6d557e8959678f9d4edbcb3d5a6dfba14b429b18
2024-07-24 17:00:29 +08:00
hiyouga
27f2c3cae1 fix #4925
Former-commit-id: 79c336e2339974471627487858d59e4ed2152370
2024-07-24 16:56:58 +08:00
hiyouga
48f0819327 fix #4944
Former-commit-id: 9e8cf3b21a0b12d1413c3c7f3d60399784909242
2024-07-24 16:42:51 +08:00
hiyouga
5c6d88e91c add mistral nemo model
Former-commit-id: 428bb49f53b32947bc0a62ca19ab10844154c07c
2024-07-24 16:25:53 +08:00
hiyouga
0a04d9470f add llama3.1
Former-commit-id: 3c433890f9b61c520572f5233aae70584da0f330
2024-07-24 16:20:11 +08:00
Liuww
f0408c0dde fix: Repair the issue where quantization failed after merging the adapter.
Former-commit-id: 8109561b7f577d448f8bca7e569f7f443cf6bb52
2024-07-24 14:31:29 +08:00
hiyouga
a041f4a111 tiny fix
Former-commit-id: bf6a2f032c598f969708c1c3db4875d6239c41a9
2024-07-22 21:10:15 +08:00
hoshi-hiyouga
cdf9dae53e fix #4917
Former-commit-id: e26919aafd8436489d065789c9c25d72c8d05a6d
2024-07-22 11:28:31 +08:00
hiyouga
1917f431f5 tiny fix
Former-commit-id: 9133316e558a3c8744f5eb6ab8678686bf4859ed
2024-07-22 00:06:03 +08:00
hiyouga
a770afbff2 fix flashattn + packing
Former-commit-id: 4adc6ce4abc718c25f39b316bfc3352d0d01ed1e
2024-07-21 17:07:45 +08:00
huangpan.foo
b1a5bf025b update deepseek template
Former-commit-id: f5ca86ec95bb301df42ffaa6923fc3037a224e34
2024-07-19 15:02:54 +08:00
hiyouga
adff3e5050 set dev version
Former-commit-id: 0b9a2275dc533b65578278f979ce053e95a644b3
2024-07-19 02:01:46 +08:00
127 changed files with 4083 additions and 1316 deletions

35
.env.local Normal file
View File

@@ -0,0 +1,35 @@
# Note: actually we do not support .env, just for reference
# api
API_HOST=0.0.0.0
API_PORT=8000
API_KEY=
API_MODEL_NAME=gpt-3.5-turbo
FASTAPI_ROOT_PATH=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
FORCE_TORCHRUN=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=
MASTER_PORT=
NNODES=
RANK=
NPROC_PER_NODE=
# wandb
WANDB_DISABLED=
WANDB_PROJECT=huggingface
WANDB_API_KEY=
# gradio ui
GRADIO_SHARE=False
GRADIO_SERVER_NAME=0.0.0.0
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
# setup
ENABLE_SHORT_CONSOLE=1
# reserved (do not use)
LLAMABOARD_ENABLED=
LLAMABOARD_WORKDIR=

View File

@@ -3,14 +3,14 @@ name: tests
on:
push:
branches:
- main
- "main"
paths:
- "**.py"
- "requirements.txt"
- ".github/workflows/*.yml"
pull_request:
branches:
- main
- "main"
paths:
- "**.py"
- "requirements.txt"
@@ -18,13 +18,27 @@ on:
jobs:
tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
os:
- "ubuntu-latest"
- "windows-latest"
- "macos-13"
runs-on: ${{ matrix.os }}
environment:
name: tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}
steps:
- name: Checkout
@@ -33,13 +47,14 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.8"
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: "setup.py"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install git+https://github.com/huggingface/transformers.git
python -m pip install ".[torch,dev]"
- name: Check quality

2
.gitignore vendored
View File

@@ -160,6 +160,8 @@ cython_debug/
.idea/
# custom .gitignore
ms_cache/
hf_cache/
cache/
config/
saves/

View File

@@ -1,6 +1,6 @@
.PHONY: quality style test
check_dirs := scripts src tests
check_dirs := scripts src tests setup.py
quality:
ruff check $(check_dirs)

184
README.md
View File

@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-72-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-91-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -21,13 +21,14 @@
**Fine-tuning a large language model can be easy as...**
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
Choose your path:
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Local machine**: Please refer to [usage](#getting-started)
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
## Table of Contents
@@ -46,11 +47,11 @@ Choose your path:
## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -71,15 +72,23 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
<details><summary>Full Changelog</summary>
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
<details><summary>Full Changelog</summary>
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion.
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
@@ -91,7 +100,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
@@ -103,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
@@ -151,32 +160,34 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models
| Model | Model size | Template |
| ------------------------------------------------------------ | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
| Model | Model size | Template |
| ----------------------------------------------------------------- | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -200,6 +211,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!TIP]
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
## Provided Datasets
<details><summary>Pre-training datasets</summary>
@@ -259,7 +273,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -276,6 +292,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -296,20 +314,20 @@ huggingface-cli login
| Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- |
| python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.3.0 |
| transformers | 4.41.2 | 4.41.2 |
| datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 |
| trl | 0.8.6 | 0.9.4 |
| torch | 1.13.1 | 2.4.0 |
| transformers | 4.41.2 | 4.43.4 |
| datasets | 2.16.0 | 2.20.0 |
| accelerate | 0.30.1 | 0.32.0 |
| peft | 0.11.1 | 0.12.0 |
| trl | 0.8.6 | 0.9.6 |
| Optional | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.4.3 |
| flash-attn | 2.3.0 | 2.5.9 |
| vllm | 0.4.3 | 0.5.0 |
| flash-attn | 2.3.0 | 2.6.3 |
### Hardware Requirement
@@ -338,7 +356,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -422,16 +440,24 @@ For CUDA users:
```bash
cd docker/docker-cuda/
docker-compose up -d
docker-compose exec llamafactory bash
docker compose up -d
docker compose exec llamafactory bash
```
For Ascend NPU users:
```bash
cd docker/docker-npu/
docker-compose up -d
docker-compose exec llamafactory bash
docker compose up -d
docker compose exec llamafactory bash
```
For AMD ROCm users:
```bash
cd docker/docker-rocm/
docker compose up -d
docker compose exec llamafactory bash
```
<details><summary>Build without Docker Compose</summary>
@@ -493,13 +519,42 @@ docker run -dit \
docker exec -it llamafactory bash
```
For AMD ROCm users:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-v ./saves:/app/saves \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/kfd \
--device /dev/dri \
--shm-size 16G \
--name llamafactory \
llamafactory:latest
docker exec -it llamafactory bash
```
</details>
<details><summary>Details about volume</summary>
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
</details>
@@ -510,7 +565,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
### Download from ModelScope Hub
@@ -600,7 +655,26 @@ If you have a project that should be incorporated, please contact via email or c
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburghs Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -618,7 +692,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-72-green)](#使用了-llama-factory-的项目)
[![Citation](https://img.shields.io/badge/citation-91-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -21,13 +21,15 @@
**微调大模型可以像这样轻松…**
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd-d76c6d0a6594
https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
选择你的打开方式:
- **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **PAI-DSW**https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **本地机器**:请见[如何使用](#如何使用)
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
## 目录
@@ -46,11 +48,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **先进算法**[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -71,15 +73,23 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
<details><summary>展开日志</summary>
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
<details><summary>展开日志</summary>
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `paligemma` 模板进行微调使其获得对话能力。
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
@@ -91,7 +101,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练24GB 可训练 Llama-2-7B-56k。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
@@ -103,7 +113,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
@@ -151,32 +161,34 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型
| 模型名 | 模型大小 | Template |
| ------------------------------------------------------------ | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
| 模型名 | 模型大小 | Template |
| ----------------------------------------------------------------- | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> 对于所有“基座”Base模型`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Instruct/Chat模型请务必使用**对应的模板**。
@@ -200,6 +212,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!TIP]
> 有关 PPO 的实现细节,请参考[此博客](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html)。
## 数据集
<details><summary>预训练数据集</summary>
@@ -259,7 +274,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -276,6 +293,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -296,20 +315,20 @@ huggingface-cli login
| 必需项 | 至少 | 推荐 |
| ------------ | ------- | --------- |
| python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.3.0 |
| transformers | 4.41.2 | 4.41.2 |
| datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 |
| trl | 0.8.6 | 0.9.4 |
| torch | 1.13.1 | 2.4.0 |
| transformers | 4.41.2 | 4.43.4 |
| datasets | 2.16.0 | 2.20.0 |
| accelerate | 0.30.1 | 0.32.0 |
| peft | 0.11.1 | 0.12.0 |
| trl | 0.8.6 | 0.9.6 |
| 可选项 | 至少 | 推荐 |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.4.3 |
| flash-attn | 2.3.0 | 2.5.9 |
| vllm | 0.4.3 | 0.5.0 |
| flash-attn | 2.3.0 | 2.6.3 |
### 硬件依赖
@@ -338,7 +357,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```
可选的额外依赖项torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
可选的额外依赖项torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -422,16 +441,24 @@ CUDA 用户:
```bash
cd docker/docker-cuda/
docker-compose up -d
docker-compose exec llamafactory bash
docker compose up -d
docker compose exec llamafactory bash
```
昇腾 NPU 用户:
```bash
cd docker/docker-npu/
docker-compose up -d
docker-compose exec llamafactory bash
docker compose up -d
docker compose exec llamafactory bash
```
AMD ROCm 用户:
```bash
cd docker/docker-rocm/
docker compose up -d
docker compose exec llamafactory bash
```
<details><summary>不使用 Docker Compose 构建</summary>
@@ -493,13 +520,42 @@ docker run -dit \
docker exec -it llamafactory bash
```
AMD ROCm 用户:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-v ./saves:/app/saves \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/kfd \
--device /dev/dri \
--shm-size 16G \
--name llamafactory \
llamafactory:latest
docker exec -it llamafactory bash
```
</details>
<details><summary>数据卷详情</summary>
- hf_cache使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
- data宿主机中存放数据集的文件夹路径
- output将导出目录设置为该路径后即可在宿主机中访问导出后的模型
- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供
- `data`:宿主机中存放数据集的文件夹路径
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
</details>
@@ -510,7 +566,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)
### 从魔搭社区下载
@@ -600,7 +656,26 @@ run_name: test_run # 可选
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburghs Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
@@ -618,7 +693,7 @@ run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用

View File

@@ -23,6 +23,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
@@ -107,7 +108,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
### Preference Dataset
Preference datasets are used for reward modeling, DPO training and ORPO training.
Preference datasets are used for reward modeling, DPO training, ORPO and SimPO training.
It requires a better response in `chosen` column and a worse response in `rejected` column.
@@ -139,67 +140,15 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
### KTO Dataset
- [Example dataset](kto_en_demo.json)
An additional column `kto_tag` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
KTO datasets require a extra `kto_tag` column containing the boolean human feedback.
### Multimodal Image Dataset
```json
[
{
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"kto_tag": "human feedback [true/false] (required)"
}
]
```
An additional column `images` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
### Multimodal Video Dataset
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"kto_tag": "kto_tag"
}
}
```
### Multimodal Dataset
- [Example dataset](mllm_demo.json)
Multimodal datasets require a `images` column containing the paths to the input images. Currently we only support one image.
```json
[
{
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"images": [
"image path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"images": "images"
}
}
```
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
## Sharegpt Format
@@ -252,6 +201,10 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### Pre-training Dataset
Not yet supported, please use the [alpaca](#alpaca-format) format.
### Preference Dataset
- [Example dataset](dpo_en_demo.json)
@@ -302,6 +255,125 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### KTO Dataset
- [Example dataset](kto_en_demo.json)
KTO datasets require a extra `kto_tag` column containing the boolean human feedback.
```json
[
{
"conversations": [
{
"from": "human",
"value": "human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"kto_tag": "human feedback [true/false] (required)"
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"kto_tag": "kto_tag"
}
}
```
### Multimodal Image Dataset
- [Example dataset](mllm_demo.json)
Multimodal image datasets require a `images` column containing the paths to the input images.
The number of images should be identical to the `<image>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<image>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"images": [
"image path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"images": "images"
}
}
```
### Multimodal Video Dataset
- [Example dataset](mllm_video_demo.json)
Multimodal video datasets require a `videos` column containing the paths to the input videos.
The number of videos should be identical to the `<video>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<video>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"videos": [
"video path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"videos": "videos"
}
}
```
### OpenAI Format
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
@@ -345,7 +417,3 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
}
```
The KTO datasets and multimodal datasets in sharegpt format are similar to the alpaca format.
Pre-training datasets are **incompatible** with the sharegpt format.

View File

@@ -23,6 +23,7 @@
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None",
"videos": "数据集代表视频输入的表头名称默认None",
"chosen": "数据集代表更优回答的表头名称默认None",
"rejected": "数据集代表更差回答的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
@@ -107,7 +108,7 @@
### 偏好数据集
偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。
偏好数据集用于奖励模型训练、DPO 训练、ORPO 训练和 SimPO 训练。
它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
@@ -139,67 +140,15 @@
### KTO 数据集
- [样例数据集](kto_en_demo.json)
KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#sharegpt-格式)
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
### 多模态图像数据集
```json
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"kto_tag": "人类反馈 [true/false](必填)"
}
]
```
多模态图像数据集需要提供额外的 `images` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
### 多模态视频数据集
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"kto_tag": "kto_tag"
}
}
```
### 多模态数据集
- [样例数据集](mllm_demo.json)
多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。目前我们仅支持单张图像输入。
```json
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"images": [
"图像路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"images": "images"
}
}
```
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
## Sharegpt 格式
@@ -252,6 +201,10 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
```
### 预训练数据集
尚不支持,请使用 [alpaca](#alpaca-格式) 格式。
### 偏好数据集
- [样例数据集](dpo_zh_demo.json)
@@ -302,6 +255,125 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
}
```
### KTO 数据集
- [样例数据集](kto_en_demo.json)
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
```json
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"kto_tag": "人类反馈 [true/false](必填)"
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"kto_tag": "kto_tag"
}
}
```
### 多模态图像数据集
- [样例数据集](mllm_demo.json)
多模态图像数据集需要额外添加一个 `images` 列,包含输入图像的路径。
注意图片的数量必须与文本中所有 `<image>` 标记的数量严格一致。
```json
[
{
"conversations": [
{
"from": "human",
"value": "<image>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"images": [
"图像路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"images": "images"
}
}
```
### 多模态视频数据集
- [样例数据集](mllm_video_demo.json)
多模态视频数据集需要额外添加一个 `videos` 列,包含输入视频的路径。
注意视频的数量必须与文本中所有 `<video>` 标记的数量严格一致。
```json
[
{
"conversations": [
{
"from": "human",
"value": "<video>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"videos": [
"视频路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"videos": "videos"
}
}
```
### OpenAI 格式
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
@@ -345,7 +417,3 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
}
}
```
Sharegpt 格式中的 KTO 数据集和多模态数据集与 alpaca 格式的类似。
预训练数据集**不支持** sharegpt 格式。

BIN
data/mllm_demo_data/1.mp4 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/2.avi Normal file

Binary file not shown.

BIN
data/mllm_demo_data/3.mp4 Normal file

Binary file not shown.

View File

@@ -1,9 +1,9 @@
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags
# FROM cosdt/cann:8.0.rc1-910-ubuntu22.04
FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04
# FROM cosdt/cann:8.0.rc1-910-openeuler22.03
# FROM cosdt/cann:8.0.rc1-910b-openeuler22.03
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
FROM ascendai/cann:8.0.rc1-910b-ubuntu22.04-py3.8
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8
# Define environments
ENV DEBIAN_FRONTEND=noninteractive

View File

@@ -0,0 +1,57 @@
FROM hardandheavy/transformers-rocm:2.1.0
# Define environments
ENV MAX_JOBS=4
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
# Define installation arguments
ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG PIP_INDEX=https://pypi.org/simple
# Set the working directory
WORKDIR /app
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
# Copy the rest of the application into the image
COPY . /app
# Install the LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_BNB" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
fi; \
if [ "$INSTALL_VLLM" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
fi; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi
# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
# Expose port 7860 for the LLaMA Board
ENV GRADIO_SERVER_PORT 7860
EXPOSE 7860
# Expose port 8000 for the API service
ENV API_PORT 8000
EXPOSE 8000

View File

@@ -0,0 +1,29 @@
services:
llamafactory:
build:
dockerfile: ./docker/docker-rocm/Dockerfile
context: ../..
args:
INSTALL_BNB: false
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
INSTALL_FLASHATTN: false
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../data:/app/data
- ../../output:/app/output
- ../../saves:/app/saves
ports:
- "7860:7860"
- "8000:8000"
ipc: host
tty: true
stdin_open: true
command: bash
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
restart: unless-stopped

View File

@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### Multimodal DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```
#### Reward Modeling
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO Training
```bash
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using Adam-mini
```bash
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```
#### LoRA+ Fine-Tuning
```bash

View File

@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### 多模态 DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```
#### 奖励模型训练
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO 训练
```bash
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 多模态指令监督微调
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### 使用 Adam-mini 进行全参数训练
```bash
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```
#### LoRA+ 微调
```bash

View File

@@ -1,27 +1,22 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
### method
stage: sft
do_train: true
finetuning_type: full
use_badam: true
badam_mode: layer
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
deepspeed: examples/deepspeed/ds_z3_config.json
use_adam_mini: true
### dataset
dataset: identity,alpaca_en_demo
template: llama3
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b/full/sft
output_dir: saves/qwen2-1_5b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -30,10 +25,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1

View File

@@ -10,6 +10,7 @@ badam_mode: layer
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
# deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -29,7 +30,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

View File

@@ -29,11 +29,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1

View File

@@ -2,5 +2,5 @@
python scripts/llama_pro.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir models/llama3-8b-instruct-pro \
--output_dir models/llama3-8b-pro \
--num_expand 8

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: models/llama3-8b-instruct-pro
model_name_or_path: models/llama3-8b-pro
### method
stage: sft
@@ -18,7 +18,7 @@ overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b-instruct-pro/freeze/sft
output_dir: saves/llama3-8b-pro/freeze/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -26,7 +26,7 @@ overwrite_output_dir: true
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
optim: paged_adamw_8bit
learning_rate: 1.0e-4
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

View File

@@ -0,0 +1,5 @@
#!/bin/bash
python scripts/pissa_init.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir models/llama3-8b-pissa

View File

@@ -1,3 +1,2 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: vicuna
visual_inputs: true
template: llava

View File

@@ -0,0 +1,2 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl

View File

@@ -0,0 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
template: qwen2_vl
finetuning_type: lora
### export
export_dir: models/qwen2_vl_lora_sft
export_size: 2
export_device: cpu
export_legacy_format: false

View File

@@ -7,7 +7,7 @@ do_predict: true
finetuning_type: full
### dataset
dataset: identity,alpaca_en_demo
eval_dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 50

View File

@@ -25,7 +25,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-4
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

View File

@@ -0,0 +1,39 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo,identity
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

View File

@@ -1,6 +1,5 @@
### model
model_name_or_path: llava-hf/llava-1.5-7b-hf
visual_inputs: true
### method
stage: sft
@@ -10,7 +9,7 @@ lora_target: all
### dataset
dataset: mllm_demo
template: vicuna
template: llava
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true

View File

@@ -0,0 +1,41 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
pref_beta: 0.1
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: rlhf_v
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

View File

@@ -0,0 +1,39 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
### dataset
dataset: mllm_demo,identity # video: mllm_video_demo
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

View File

@@ -1,8 +1,8 @@
transformers>=4.41.2
datasets>=2.16.0
accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
transformers>=4.41.2,<=4.45.0
datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.33.0
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
gradio>=4.0.0
pandas>=2.0.0
scipy

View File

@@ -27,7 +27,7 @@ from llamafactory.chat import ChatModel
def calculate_flops(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 256,
seq_length: int = 512,
flash_attn: str = "auto",
):
r"""
@@ -36,9 +36,11 @@ def calculate_flops(
"""
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
flops, macs, params = get_model_profile(
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
)
print("FLOPs:", flops)
print("MACs:", macs)
print("Params:", params)

View File

@@ -25,7 +25,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llamafactory.data import get_dataset
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
@@ -39,16 +39,17 @@ def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate,
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@@ -66,7 +67,8 @@ def calculate_lr(
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
@@ -84,7 +86,7 @@ def calculate_lr(
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len

164
scripts/cal_mfu.py Normal file
View File

@@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig
from llamafactory.train.tuner import run_exp
BASE = 2 # gemm (add + mul)
def compute_model_flops(
model_name_or_path: str,
total_batch_size: int,
seq_length: int,
include_backward: bool = True,
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
r"""
Calculates the FLOPs of model per forward/backward pass.
"""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
intermediate_size = getattr(config, "intermediate_size", None)
num_attention_heads = getattr(config, "num_attention_heads", None)
num_key_value_heads = getattr(config, "num_key_value_heads", None)
num_hidden_layers = getattr(config, "num_hidden_layers", None)
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
# mlp module
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
# attn projector module
q_flops_per_token = BASE * hidden_size * hidden_size
o_flops_per_token = BASE * hidden_size * hidden_size
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
# attn sdpa module
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
# embedding module
embedding_flops_per_token = hidden_size * vocab_size
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
if tie_word_embeddings is False:
embedding_flops *= 2
non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
non_embedding_coeff, embedding_coeff = 1, 1
if include_backward:
non_embedding_coeff += 2
embedding_coeff += 2
if include_recompute:
non_embedding_coeff += 1
total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
if include_flashattn:
total_flops += sdpa_flops
return total_flops
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
elif "A100" in device_name or "A800" in device_name:
return 312 * 1e12 * world_size
elif "V100" in device_name:
return 125 * 1e12 * world_size
elif "4090" in device_name:
return 98 * 1e12 * world_size
else:
raise NotImplementedError("Device not supported: {}.".format(device_name))
def calculate_mfu(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 1024,
num_steps: int = 100,
finetuning_type: str = "lora",
flash_attn: str = "auto",
deepspeed_stage: int = 0,
disable_gc: bool = False,
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
"model_name_or_path": model_name_or_path,
"flash_attn": flash_attn,
"disable_gradient_checkpointing": disable_gc,
"enable_liger_kernel": liger_kernel,
"use_unsloth_gc": unsloth_gc,
"stage": "pt",
"do_train": True,
"finetuning_type": finetuning_type,
"dataset": "c4_demo",
"cutoff_len": seq_length,
"output_dir": os.path.join("saves", "test_mfu"),
"logging_strategy": "no",
"save_strategy": "no",
"save_only_model": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": num_steps,
"bf16": True,
}
if deepspeed_stage in [2, 3]:
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print("MFU: {:.2f}%".format(mfu_value * 100))
if __name__ == "__main__":
fire.Fire(calculate_mfu)

View File

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llamafactory.data import get_dataset
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
@@ -55,12 +55,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
return super().__call__(chosen_features)
def cal_ppl(
def calculate_ppl(
model_name_or_path: str,
save_name: str,
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
@@ -69,7 +69,7 @@ def cal_ppl(
):
r"""
Calculates the ppl on the dataset of the pre-trained models.
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
@@ -88,7 +88,8 @@ def cal_ppl(
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -129,4 +130,4 @@ def cal_ppl(
if __name__ == "__main__":
fire.Fire(cal_ppl)
fire.Fire(calculate_ppl)

View File

@@ -18,21 +18,21 @@ from collections import defaultdict
import fire
from tqdm import tqdm
from llamafactory.data import get_dataset
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
interval: int = 1000,
):
r"""
Calculates the distribution of the input lengths in the dataset.
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@@ -48,7 +48,8 @@ def length_cdf(
)
)
tokenizer_module = load_tokenizer(model_args)
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):

View File

@@ -19,7 +19,7 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import fire
import torch
@@ -47,8 +47,8 @@ def block_expansion(
model_name_or_path: str,
output_dir: str,
num_expand: int,
shard_size: Optional[str] = "2GB",
save_safetensors: Optional[bool] = False,
shard_size: str = "2GB",
save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.

View File

@@ -16,7 +16,7 @@
import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional
from typing import Any, Dict
import fire
import torch
@@ -86,7 +86,10 @@ def save_config(input_dir: str, output_dir: str):
def llamafy_baichuan2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
input_dir: str,
output_dir: str,
shard_size: str = "2GB",
save_safetensors: bool = True,
):
r"""
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.

View File

@@ -16,7 +16,7 @@
import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional
from typing import Any, Dict
import fire
import torch
@@ -139,7 +139,10 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
def llamafy_qwen(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
input_dir: str,
output_dir: str,
shard_size: str = "2GB",
save_safetensors: bool = False,
):
r"""
Converts the Qwen models in the same format as LLaMA2.

View File

@@ -67,7 +67,7 @@ def quantize_loftq(
loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir))

View File

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
def quantize_pissa(
model_name_or_path: str,
output_dir: str,
pissa_iter: int = 4,
pissa_iter: int = 16,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
@@ -62,6 +62,7 @@ def quantize_pissa(
pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir))

View File

@@ -14,11 +14,12 @@
import os
import re
from typing import List
from setuptools import find_packages, setup
def get_version():
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
@@ -26,27 +27,37 @@ def get_version():
return version
def get_requires():
def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> List[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0"],
"deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
"liger-kernel": ["liger-kernel"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3"],
"vllm": ["vllm>=0.4.3,<=0.6.0"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"dev": ["ruff", "pytest"],
@@ -70,7 +81,7 @@ def main():
python_requires=">=3.8.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",

View File

@@ -20,22 +20,27 @@ Level:
Dependency graph:
main:
transformers>=4.41.2
datasets>=2.16.0
accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
transformers>=4.41.2,<=4.45.0
datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.33.0
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<=4.42.4
transformers>=4.41.2,<=4.45.0
packing:
transformers>=4.41.2,<=4.42.4
patcher:
transformers==4.41.2 (chatglm)
transformers>=4.41.2,<=4.45.0
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
"""
from .cli import VERSION
from .extras.env import VERSION
__version__ = VERSION

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional
from typing_extensions import Annotated
@@ -50,14 +52,24 @@ if is_uvicorn_available():
import uvicorn
async def sweeper() -> None:
while True:
torch_gc()
await asyncio.sleep(300)
@asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine_type == "huggingface":
asyncio.create_task(sweeper())
yield
torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -65,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY")
api_key = os.environ.get("API_KEY", None)
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -79,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(

View File

@@ -16,6 +16,7 @@ import base64
import io
import json
import os
import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
@@ -51,9 +52,8 @@ if is_requests_available():
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel
from ..data.mm_plugin import ImageInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -69,7 +69,7 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0:
@@ -104,15 +104,14 @@ def _process_request(
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if image_url.startswith("data:image"): # base64 image
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
image_path = io.BytesIO(image_data)
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url): # local file
image_path = open(image_url, "rb")
image_stream = open(image_url, "rb")
else: # web uri
image_path = requests.get(image_url, stream=True).raw
image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_path).convert("RGB")
image = Image.open(image_stream).convert("RGB")
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

View File

@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -35,6 +35,12 @@ class Response:
class BaseEngine(ABC):
r"""
Base class for inference engine of chat models.
Must implements async methods: chat(), stream_chat() and get_scores().
"""
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None: ...
) -> None:
r"""
Initializes an inference engine.
"""
...
@abstractmethod
async def chat(
@@ -56,9 +66,14 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]: ...
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
...
@abstractmethod
async def stream_chat(
@@ -66,13 +81,22 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...
) -> AsyncGenerator[str, None]:
r"""
Gets the response token-by-token of the chat model.
"""
...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]: ...
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
...

View File

@@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..data.mm_plugin import ImageInput, VideoInput
from .base_engine import BaseEngine, Response
@@ -38,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
@@ -56,10 +64,16 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
)
return task.result()
async def achat(
@@ -67,20 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -93,10 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token
def get_scores(
@@ -104,6 +130,9 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
@@ -112,6 +141,9 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
return await self.engine.get_scores(batch_input, **input_kwargs)

View File

@@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
@@ -29,12 +31,11 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -54,7 +55,7 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
@@ -78,31 +79,30 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if (
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and template.image_token not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = template.image_token + messages[0]["content"]
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None:
mm_input_dict.update({"images": [image], "imglens": [1]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
if video is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
)
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
@@ -164,8 +164,10 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length
@@ -180,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -215,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -267,12 +271,14 @@ class HuggingfaceEngine(BaseEngine):
return scores
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -289,18 +295,21 @@ class HuggingfaceEngine(BaseEngine):
system,
tools,
image,
video,
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -317,6 +326,7 @@ class HuggingfaceEngine(BaseEngine):
system,
tools,
image,
video,
input_kwargs,
)
async with self.semaphore:
@@ -328,6 +338,7 @@ class HuggingfaceEngine(BaseEngine):
except StopAsyncIteration:
break
@override
async def get_scores(
self,
batch_input: List[str],

View File

@@ -15,32 +15,31 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
if is_vllm_version_greater_than_0_5_1():
pass
elif is_vllm_version_greater_than_0_5():
from vllm.multimodal.image import ImagePixelData
else:
from vllm.sequence import MultiModalData
if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.generating_args = generating_args.to_dict()
engine_args = {
@@ -85,19 +84,11 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if model_args.visual_inputs:
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
@@ -110,37 +101,18 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
if image is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
if is_vllm_version_greater_than_0_5_1():
multi_modal_data = {"image": pixel_values}
elif is_vllm_version_greater_than_0_5():
multi_modal_data = ImagePixelData(image=pixel_values)
else: # TODO: remove vllm 0.4.3 support
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
@@ -185,6 +157,17 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True,
)
if image is not None: # add image features
if not isinstance(image, (str, ImageObject)):
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
if isinstance(image, str):
image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image}
else:
multi_modal_data = None
result_generator = self.model.generate(
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
@@ -193,16 +176,18 @@ class VllmEngine(BaseEngine):
)
return result_generator
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, image, **input_kwargs)
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for request_output in generator:
final_output = request_output
@@ -219,21 +204,24 @@ class VllmEngine(BaseEngine):
return results
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, image, **input_kwargs)
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
@override
async def get_scores(
self,
batch_input: List[str],

View File

@@ -118,4 +118,4 @@ def main():
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}".format(command))
raise NotImplementedError("Unknown command: {}.".format(command))

View File

@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
from .collator import (
KTODataCollatorWithPadding,
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -20,6 +25,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",

View File

@@ -14,9 +14,7 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger
from .data_utils import Role
@@ -27,88 +25,117 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
def _convert_images(
images: Sequence["ImageInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
outputs.append(os.path.join(data_args.dataset_dir, image))
else:
outputs.append(image)
if len(images) == 0:
return None
return outputs
images = images[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
images[i] = os.path.join(data_args.dataset_dir, images[i])
return images
def _convert_videos(
videos: Sequence["VideoInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if len(videos) == 0:
return None
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
return videos
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
prompt = []
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])
if dataset_attr.query and example[dataset_attr.query]:
query.append(example[dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], str)
and isinstance(example[dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
]
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -119,74 +146,79 @@ def convert_sharegpt(
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = example[dataset_attr.messages]
if (
dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
if len(messages) == 0:
continue
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], dict)
and isinstance(example[dataset_attr.rejected], dict)
): # pairwise example
chosen = example[dataset_attr.chosen]
rejected = example[dataset_attr.rejected]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
return outputs
if broken_data:
logger.warning("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def align_dataset(
@@ -197,11 +229,12 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "...",
images: [],
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
@@ -209,19 +242,6 @@ def align_dataset(
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
@@ -232,8 +252,7 @@ def align_dataset(
return dataset.map(
convert_func,
batched=True,
batched=False,
remove_columns=column_names,
features=features,
**kwargs,
)

View File

@@ -16,12 +16,18 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
if TYPE_CHECKING:
from transformers import ProcessorMixin
from .template import Template
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@@ -62,7 +68,42 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels and images.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
batch_images.extend(images)
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_seqlens.append(len(feature["input_ids"]))
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update(mm_inputs)
return features
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
@@ -80,7 +121,7 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
@@ -99,20 +140,16 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
"images": feature["images"],
"videos": feature["videos"],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
@@ -126,19 +163,16 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
"images": feature["images"],
"videos": feature["videos"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
"images": feature["images"],
"videos": feature["videos"],
}
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
@@ -148,7 +182,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in batch:
if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)

View File

@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
@@ -67,14 +70,16 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
"""
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

View File

@@ -16,21 +16,36 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
from .tool_utils import get_tool_utils
if TYPE_CHECKING:
from .tool_utils import FunctionCall
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default", "glm4"]] = None
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
...
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
raise NotImplementedError
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
@override
def apply(self, **kwargs) -> SLOTS:
return self.slots
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
@override
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
@@ -81,13 +98,9 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
@@ -119,22 +132,17 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
return self._tool_extractor(content)
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)

View File

@@ -27,7 +27,6 @@ from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
@@ -49,6 +48,9 @@ def _load_single_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -118,7 +120,7 @@ def _load_single_dataset(
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num]
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
@@ -142,6 +144,9 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Gets the merged datasets in the standard format.
"""
if dataset_names is None:
return None
@@ -165,6 +170,9 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None:
return None
@@ -180,7 +188,13 @@ def _get_preprocessed_dataset(
desc="Running tokenizer on dataset",
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
dataset = dataset.map(
preprocess_func,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
**kwargs,
)
if training_args.should_log:
try:
@@ -196,6 +210,7 @@ def _get_preprocessed_dataset(
def get_dataset(
template: "Template",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
@@ -203,10 +218,9 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
@@ -217,6 +231,7 @@ def get_dataset(
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
@@ -270,6 +285,7 @@ def get_dataset(
dataset_module = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]

View File

@@ -0,0 +1,400 @@
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_pyav_available():
import av
if TYPE_CHECKING:
import torch
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
ImageInput = Union[str, EncodedImage, ImageObject]
VideoInput = str
def _regularize_images(
images: Sequence["ImageInput"],
processor: "ProcessorMixin",
max_resolution: Optional[int] = None,
) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading, resizing and converting.
"""
if max_resolution is None:
max_resolution: int = getattr(processor, "image_resolution", 512)
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
if max(image.width, image.height) > max_resolution:
factor = max_resolution / max(image.width, image.height)
image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
results.append(image)
return results
def _regularize_videos(
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
video_resolution: int = getattr(processor, "video_resolution", 128)
video_fps: float = getattr(processor, "video_fps", 1.0)
video_maxlen: int = getattr(processor, "video_maxlen", 64)
video_factor: int = getattr(processor, "video_factor", 1)
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = min(video_maxlen, sample_frames) # reduce length <= maxlen
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = _regularize_images(frames, processor, video_resolution)
results.append(frames)
return results
def _get_mm_inputs(
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
input_dict = {"images": None} # default key
if len(images) != 0:
images = _regularize_images(images, processor)
input_dict["images"] = images
if len(videos) != 0:
videos = _regularize_videos(videos, processor)
input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
return image_processor(**input_dict, return_tensors="pt")
else:
return {}
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
def _validate_input(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
) -> None:
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(images, videos)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(images, videos)
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
"""
self._validate_input(images, videos)
return {}
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = _get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
class Qwen2vlPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = _get_mm_inputs(images, videos, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
)
num_video_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
PLUGINS = {
"base": BasePlugin,
"llava": LlavaPlugin,
"paligemma": PaliGemmaPlugin,
"qwen2_vl": Qwen2vlPlugin,
}
def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))
return plugin_class(image_token, video_token)

View File

@@ -43,6 +43,7 @@ class DatasetAttr:
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
videos: Optional[str] = None
# rlhf columns
chosen: Optional[str] = None
rejected: Optional[str] = None
@@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:

View File

@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
else:

View File

@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
@@ -35,14 +37,13 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -55,6 +56,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@@ -62,15 +65,13 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
@@ -78,7 +79,6 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
@@ -88,39 +88,27 @@ def preprocess_feedback_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
tools=examples["tools"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -129,11 +117,8 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num

View File

@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
@@ -34,16 +36,15 @@ def _encode_pairwise_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@@ -51,13 +52,9 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
source_len, target_len = infer_seqlen(
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
) # consider the response is more important
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
# consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
@@ -66,7 +63,6 @@ def _encode_pairwise_example(
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
@@ -76,36 +72,25 @@ def preprocess_pairwise_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
cutoff_len=data_args.cutoff_len,
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
@@ -113,15 +98,8 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"].append(
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
)
model_inputs["rejected_token_type_ids"].append(
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs

View File

@@ -27,16 +27,16 @@ if TYPE_CHECKING:
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}

View File

@@ -13,20 +13,7 @@
# limitations under the License.
import bisect
from typing import TYPE_CHECKING, List, Sequence, Tuple
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from typing import List, Sequence, Tuple
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
@@ -61,23 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
r"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.

View File

@@ -17,13 +17,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
@@ -35,47 +36,49 @@ def _encode_supervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
messages = prompt + response
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= data_args.cutoff_len:
if total_length >= cutoff_len:
break
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if data_args.train_on_prompt:
if train_on_prompt:
source_label = source_ids
elif turn_idx != 0 and template.efficient_eos:
elif template.efficient_eos:
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
if mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
input_ids += source_ids + target_ids
labels += source_label + target_label
if mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
@@ -90,37 +93,34 @@ def preprocess_supervised_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
cutoff_len=data_args.cutoff_len,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
@@ -129,28 +129,34 @@ def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels = [], []
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=None,
data_args=data_args,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
@@ -160,16 +166,21 @@ def preprocess_packed_supervised_dataset(
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
@@ -190,6 +201,8 @@ def preprocess_packed_supervised_dataset(
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
return model_inputs

View File

@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
@@ -34,28 +36,25 @@ def _encode_unsupervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
@@ -67,36 +66,31 @@ def preprocess_unsupervised_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs

View File

@@ -15,15 +15,21 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras.logging import get_logger
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from ..hparams import DataArguments
from .formatter import SLOTS, Formatter
from .mm_plugin import BasePlugin
logger = get_logger(__name__)
@@ -41,9 +47,9 @@ class Template:
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
self,
@@ -147,6 +153,7 @@ class Template:
@dataclass
class Llama2Template(Template):
@override
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
@@ -190,7 +197,7 @@ class Llama2Template(Template):
return encoded_messages
TEMPLATES: Dict[str, Template] = {}
TEMPLATES: Dict[str, "Template"] = {}
def _register_template(
@@ -205,9 +212,9 @@ def _register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Sequence[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None:
r"""
Registers a chat template.
@@ -254,9 +261,9 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
mm_plugin=mm_plugin,
)
@@ -300,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -310,14 +320,15 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}"
jinja_template += "{% for message in loop_messages %}"
jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
@@ -338,23 +349,30 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
return jinja_template
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template:
if name is None:
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
else:
template = TEMPLATES.get(name, None)
template = TEMPLATES.get(data_args.template, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
raise ValueError("Template {} does not exist.".format(data_args.template))
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info("Using tool format: {}.".format(data_args.tool_format))
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_tools = ToolFormatter(tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words
if template.replace_eos:
@@ -549,6 +567,15 @@ _register_template(
)
_register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -578,6 +605,7 @@ _register_template(
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -585,14 +613,14 @@ _register_template(
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"You are an AI programming assistant, utilizing the DeepSeek Coder model, "
"developed by DeepSeek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
"and other non-computer science questions, you will refuse to answer.\n"
),
)
@@ -714,6 +742,17 @@ _register_template(
)
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -758,6 +797,19 @@ _register_template(
)
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -781,6 +833,33 @@ _register_template(
)
_register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
_register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant named Sailor created by Sea AI Lab. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@@ -878,6 +957,7 @@ _register_template(
),
stop_words=["###"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)

View File

@@ -15,9 +15,12 @@
import json
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
)
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
@dataclass
class ToolUtils(ABC):
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS: ...
"""
Base class for tool utilities.
"""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
def get_function_slots() -> SLOTS:
r"""
Gets a list of slots corresponding to a single function call.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the response message.
"""
...
class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "\n" not in content:
return content
@@ -138,3 +166,17 @@ class GLM4ToolUtils(ToolUtils):
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
}
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name))
return tool_utils

View File

@@ -39,7 +39,7 @@
import json
import os
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import torch
@@ -54,18 +54,22 @@ from ..model import load_model, load_tokenizer
from .template import get_eval_template
if TYPE_CHECKING:
from numpy.typing import NDArray
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -132,7 +136,7 @@ class Evaluator:
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join(
[
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))

View File

@@ -47,6 +47,8 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>"
LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
@@ -93,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>"
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -531,6 +535,10 @@ register_model_group(
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
"Gemma-2-2B": {
DownloadSource.DEFAULT: "google/gemma-2-2b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
},
"Gemma-2-9B": {
DownloadSource.DEFAULT: "google/gemma-2-9b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
@@ -539,6 +547,10 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
},
"Gemma-2-2B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
},
"Gemma-2-9B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
@@ -619,10 +631,22 @@ register_model_group(
register_model_group(
models={
"InternLM2.5-1.8B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
},
"InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
},
"InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
},
"InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
},
"InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
@@ -631,6 +655,10 @@ register_model_group(
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
},
"InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
},
},
template="intern2",
)
@@ -739,6 +767,37 @@ register_model_group(
)
register_model_group(
models={
"LLaMA3.1-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
},
"LLaMA3.1-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
},
"LLaMA3.1-405B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
},
"LLaMA3.1-8B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
},
"LLaMA3.1-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
},
"LLaMA3.1-405B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
},
},
template="llama3",
)
register_model_group(
models={
"LLaVA1.5-7B-Chat": {
@@ -748,7 +807,7 @@ register_model_group(
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
template="llava",
vision=True,
)
@@ -768,6 +827,17 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
},
},
template="cpm3",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
@@ -791,6 +861,11 @@ register_model_group(
},
"Mistral-7B-v0.3-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
},
"Mistral-Nemo-Chat": {
DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
},
},
template="mistral",
@@ -888,27 +963,28 @@ register_model_group(
register_model_group(
models={
"PaliGemma-3B-pt-224": {
"PaliGemma-3B-pt-224-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448": {
"PaliGemma-3B-pt-448-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896": {
"PaliGemma-3B-pt-896-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224": {
"PaliGemma-3B-mix-224-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448": {
"PaliGemma-3B-mix-448-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
},
template="paligemma",
vision=True,
)
@@ -1202,6 +1278,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
},
"Qwen2-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B",
},
"Qwen2-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B",
},
"Qwen2-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B",
},
"Qwen2-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
@@ -1222,6 +1310,18 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
},
"Qwen2-Math-1.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B-Instruct",
},
"Qwen2-Math-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B-Instruct",
},
"Qwen2-Math-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B-Instruct",
},
"Qwen2-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
@@ -1263,6 +1363,38 @@ register_model_group(
)
register_model_group(
models={
"Qwen2VL-2B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
},
"Qwen2VL-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
},
"Qwen2VL-2B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
},
"Qwen2VL-2B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
},
"Qwen2VL-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
},
"Qwen2VL-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
},
},
template="qwen2_vl",
vision=True,
)
register_model_group(
models={
"SOLAR-10.7B": {
@@ -1534,6 +1666,22 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
},
"Yi-Coder-1.5B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B",
},
"Yi-Coder-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B",
},
"Yi-Coder-1.5B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B-Chat",
},
"Yi-Coder-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B-Chat",
},
},
template="yi",
)

View File

@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.8.3"
VERSION = "0.9.0"
def print_env() -> None:

View File

@@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,14 +18,21 @@
import logging
import os
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from .constants import RUNNING_LOG
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
Redirects the logging output to the logging file for LLaMA Board.
"""
def __init__(self, output_dir: str) -> None:
@@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
return super().close()
def get_logger(name: str) -> logging.Logger:
def _get_default_logging_level() -> "logging._Level":
r"""
Gets a standard logger with a stream hander to stdout.
Returns the default logging level.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError("Unknown logging level: {}.".format(env_level_str))
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
return logger
return _default_log_level
def reset_logging() -> None:
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> "logging.Logger":
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
r"""
Removes basic config of root logger. (unused in script)
Configures root logger using a stdout stream handler with an explicit format.
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
global _default_handler
with _thread_lock:
if _default_handler:
return
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "logging.Logger":
r"""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)

View File

@@ -37,7 +37,7 @@ from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
_is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())
except Exception:
_is_bf16_available = False
@@ -79,11 +79,11 @@ def check_dependencies() -> None:
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
@@ -137,7 +137,9 @@ def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
if is_torch_npu_available():
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
@@ -154,6 +156,18 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
return 0, 0
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
@@ -181,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
@@ -192,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports

View File

@@ -38,6 +38,10 @@ def _get_package_version(name: str) -> "Version":
return version.parse("0.0.0")
def is_pyav_available():
return _is_package_available("av")
def is_fastapi_available():
return _is_package_available("fastapi")
@@ -70,19 +74,14 @@ def is_starlette_available():
return _is_package_available("sse_starlette")
@lru_cache
def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0")
def is_uvicorn_available():
return _is_package_available("uvicorn")
def is_vllm_available():
return _is_package_available("vllm")
@lru_cache
def is_vllm_version_greater_than_0_5():
return _get_package_version("vllm") >= version.parse("0.5.0")
@lru_cache
def is_vllm_version_greater_than_0_5_1():
return _get_package_version("vllm") >= version.parse("0.5.1")

View File

@@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""

View File

@@ -73,6 +73,10 @@ class DataArguments:
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."},
)
preprocessing_batch_size: int = field(
default=1000,
metadata={"help": "The number of examples in one group in pre-processing."},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
@@ -141,3 +145,6 @@ class DataArguments:
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

View File

@@ -326,6 +326,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
use_adam_mini: bool = field(
default=False,
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},

View File

@@ -15,23 +15,141 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
from typing_extensions import Self
if TYPE_CHECKING:
import torch
@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
@dataclass
class ModelArguments:
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
image_resolution: int = field(
default=512,
metadata={"help": "Keeps the height or width of image below this resolution."},
)
video_resolution: int = field(
default=128,
metadata={"help": "Keeps the height or width of video below this resolution."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},
)
video_maxlen: int = field(
default=64,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
@dataclass
class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
"""
model_name_or_path: str = field(
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
@@ -77,26 +195,6 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
@@ -117,9 +215,13 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
@@ -145,22 +247,6 @@ class ModelArguments:
default="huggingface",
metadata={"help": "Backend engine used at inference."},
)
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
@@ -181,59 +267,38 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
)
block_diag_attn: bool = field(
default=False,
init=False,
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
)
def __post_init__(self):
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
self.block_diag_attn: bool = False
if self.model_name_or_path is None:
raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
@@ -247,9 +312,13 @@ class ModelArguments:
return asdict(self)
@classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map

View File

@@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES
@@ -116,11 +116,14 @@ def _check_extra_dependencies(
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
@@ -128,6 +131,9 @@ def _check_extra_dependencies(
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
@@ -163,11 +169,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage != "sft":
if training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage != "sft" and data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.")
if data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.")
if data_args.train_on_prompt or data_args.mask_history:
raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
@@ -175,21 +185,18 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage == "ppo":
if not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
@@ -208,11 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate and data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
@@ -221,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled():
@@ -246,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
@@ -377,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp
from llamafactory.train.tuner import run_exp # use absolute import
def launch():

View File

@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules
if TYPE_CHECKING:
@@ -37,7 +38,6 @@ logger = get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
@@ -46,13 +46,7 @@ def _setup_full_tuning(
return
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
@@ -63,7 +57,6 @@ def _setup_full_tuning(
def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
@@ -72,8 +65,8 @@ def _setup_freeze_tuning(
return
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
config = model.config
@@ -130,10 +123,7 @@ def _setup_freeze_tuning(
trainable_layers.append(module_name)
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
@@ -211,8 +201,7 @@ def _setup_lora_tuning(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
if (
finetuning_args.use_dora
@@ -303,9 +292,9 @@ def init_adapter(
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32

View File

@@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@@ -93,17 +95,24 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer)
if model_args.visual_inputs:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
raise ValueError(
"This multimodal LLM is not supported.\n"
"Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
"Download Yi-VL models from: https://huggingface.co/BUAADreamer"
)
else:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
if getattr(config, "model_type", None) == "qwen2_vl":
setattr(processor, "video_factor", 2)
else:
setattr(processor, "video_factor", 1)
except Exception:
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if "Processor" not in processor.__class__.__name__:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
@@ -145,12 +154,16 @@ def load_model(
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
model = AutoModelForCausalLM.from_config(config)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
else:
model = load_class.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)

View File

@@ -36,7 +36,7 @@ def configure_attn_implementation(
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:

View File

@@ -1,8 +1,10 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library.
# This code is inspired by the HuggingFace's Transformers and PEFT library,
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
# and the Unsloth library.
# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,9 +19,9 @@
# limitations under the License.
import inspect
from functools import partial
from functools import partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch
@@ -36,8 +38,70 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
return UnslothGradientCheckpointing.apply
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
@@ -52,24 +116,18 @@ def _gradient_checkpointing_enable(
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if use_unsloth_gc:
gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
else:
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
def _fp32_forward_post_hook(
@@ -97,7 +155,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
gradient_checkpointing_enable = partial(
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

View File

@@ -0,0 +1,53 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.enable_liger_kernel:
return
model_type = getattr(config, "model_type", None)
if model_type == "gemma":
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
elif model_type == "gemma2":
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
else:
logger.warning("Current model does not support liger kernel.")
return
apply_liger_kernel()
logger.info("Liger kernel has been applied to the model.")

View File

@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -68,7 +70,11 @@ def llama_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -130,14 +136,15 @@ def llama_attention_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if is_transformers_version_greater_than_4_43():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_states.size(1),
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -28,17 +28,22 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
r"""
Finds all available modules to apply lora or galore.
"""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
if model.config.model_type == "chatglm":
if model_type == "chatglm":
forbidden_modules.add("output_layer")
elif model.config.model_type == "internlm2":
elif model_type == "internlm2":
forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]:
elif model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")
if freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
module_names = set()
for name, module in model.named_modules():

View File

@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if not is_deepspeed_zero3_enabled():
return
if getattr(model.config, "model_type", None) == "dbrx":
model_type = getattr(model.config, "model_type", None)
if model_type == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
_set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
_set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
if model_type == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
if model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
if model_type == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
if model_type in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "jetmoe":
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
return
import transformers.models
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch
import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@@ -80,24 +80,98 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
self.act = ACT2FN[projector_hidden_act]
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "paligemma"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else:
return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models
r"""
Patches VLMs before loading them.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava": # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "paligemma"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
elif model_type == "qwen2_vl":
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual")
if finetuning_args.train_mm_proj_only:
raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
return forbidden_modules
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
elif model_type == "qwen2_vl": # variable length
image_seqlen = -1
return image_seqlen
def patch_target_modules(
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
r"""
Freezes vision tower for VLM LoRA tuning.
"""
model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "paligemma"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
else:
if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules

View File

@@ -21,13 +21,13 @@ from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.liger_kernel import configure_liger_kernel
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
@@ -71,6 +71,7 @@ def patch_config(
configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
configure_liger_kernel(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
@@ -89,9 +90,6 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if getattr(config, "model_type", None) == "chatglm":
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@@ -133,11 +131,9 @@ def patch_model(
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable:
prepare_model_for_training(model, model_args)
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if not model_args.use_unsloth:

View File

@@ -32,9 +32,11 @@ from transformers.utils import (
WEIGHTS_NAME,
is_safetensors_available,
)
from typing_extensions import override
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
from ..extras.misc import get_peak_memory
if is_safetensors_available():
@@ -73,8 +75,8 @@ def fix_valuehead_checkpoint(
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
decoder_state_dict = {}
v_head_state_dict = {}
os.remove(path_to_checkpoint)
decoder_state_dict, v_head_state_dict = {}, {}
for name, param in state_dict.items():
if name.startswith("v_head."):
v_head_state_dict[name] = param
@@ -90,43 +92,52 @@ def fix_valuehead_checkpoint(
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
os.remove(path_to_checkpoint)
logger.info("Value head model saved at: {}".format(output_dir))
class FixValueHeadModelCallback(TrainerCallback):
r"""
A callback for fixing the checkpoint for valuehead models.
"""
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
)
class SaveProcessorCallback(TrainerCallback):
r"""
A callback for saving the processor.
"""
def __init__(self, processor: "ProcessorMixin") -> None:
r"""
Initializes a callback for saving the processor.
"""
self.processor = processor
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(self.processor, "image_processor").save_pretrained(output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback):
r"""
Initializes a callback for converting the PiSSA adapter to a normal one.
A callback for converting the PiSSA adapter to a normal one.
"""
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
@@ -141,10 +152,8 @@ class PissaConvertCallback(TrainerCallback):
model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -162,29 +171,32 @@ class PissaConvertCallback(TrainerCallback):
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
model.save_pretrained(
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
)
) # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
model.set_adapter("default")
model.delete_adapter("pissa_init")
if "pissa_init" in model.peft_config.keys(): # backward compatibility (peft<0.12.0)
model.delete_adapter("pissa_init")
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
class LogCallback(TrainerCallback):
r"""
A callback for logging training and evaluation status.
"""
def __init__(self) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
# Progress
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
# Status
self.aborted = False
self.do_train = False
""" Web UI """
# Web UI
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
@@ -224,10 +236,8 @@ class LogCallback(TrainerCallback):
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
@override
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
@@ -236,55 +246,41 @@ class LogCallback(TrainerCallback):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
@override
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@override
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()
@override
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return
@@ -302,26 +298,31 @@ class LogCallback(TrainerCallback):
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
total_tokens=state.num_input_tokens_seen,
)
if state.num_input_tokens_seen:
logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
logs["total_tokens"] = state.num_input_tokens_seen
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
vram_allocated, vram_reserved = get_peak_memory()
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
)
)
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
@override
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return

View File

@@ -26,10 +26,11 @@ import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
return losses, chosen_rewards, rejected_rewards
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -176,7 +180,6 @@ class CustomDPOTrainer(DPOTrainer):
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
@@ -187,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
chosen_length, _ = valid_length.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -208,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
return reference_chosen_logps, reference_rejected_logps
@override
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",

View File

@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
@@ -41,13 +41,15 @@ def run_dpo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer,
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
)
# Create reference model
@@ -60,7 +62,7 @@ def run_dpo(
ref_model = None
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer
trainer = CustomDPOTrainer(

View File

@@ -25,10 +25,11 @@ import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
return Trainer._get_train_sampler(self)
@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -127,17 +132,20 @@ class CustomKTOTrainer(KTOTrainer):
"input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)],
}
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
return logps, logps / valid_length
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -153,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -173,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
@override
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",

View File

@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import KTODataCollatorWithPadding, get_dataset
from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
@@ -41,13 +41,15 @@ def run_kto(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = KTODataCollatorWithPadding(
tokenizer=tokenizer,
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
)
# Create reference model
@@ -57,7 +59,7 @@ def run_kto(
ref_model = create_ref_model(model_args, finetuning_args)
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer
trainer = CustomKTOTrainer(

View File

@@ -35,11 +35,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -133,6 +134,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
ref_model=ref_model,
tokenizer=tokenizer,
dataset=train_dataset,
optimizer=optimizer,
data_collator=data_collator,
lr_scheduler=scheduler,
)
@@ -297,13 +299,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.callback_handler.on_train_end(self.args, self.state, self.control)
@override
def create_optimizer(
self,
model: "AutoModelForCausalLMWithValueHead",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
optimizer = create_custom_optimizer(model, training_args, finetuning_args)
if optimizer is None:
decay_params, nodecay_params = [], []
decay_param_names = self.get_decay_parameter_names(model)
@@ -323,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return optimizer
@override
def create_scheduler(
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -409,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.float().detach() # use fp32 type
@override
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
@@ -477,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
torch.cat(all_masks)[:, :-1],
)
@override
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.

View File

@@ -17,9 +17,7 @@
from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorWithPadding
from ...data import get_dataset
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
@@ -43,11 +41,12 @@ def run_ppo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)

View File

@@ -16,10 +16,11 @@ from types import MethodType
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":

Some files were not shown because too many files have changed in this diff Show More